Browse Source

Add Azure AI Rerank support (#129848)

* Add Azure AI Rerank support

* address comments

* address comments

* refactor azure ai studio service

* update rerank task settings test

* add provider for rerank
Evgenii-Kazannik 3 months ago
parent
commit
d06b0c8c17
26 changed files with 2147 additions and 81 deletions
  1. 5 0
      docs/changelog/129848.yaml
  2. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 1 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
  4. 13 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
  5. 9 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java
  6. 14 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java
  7. 68 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioRerankRequestManager.java
  8. 12 22
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
  9. 10 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java
  10. 3 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java
  11. 74 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequest.java
  12. 59 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntity.java
  13. 95 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModel.java
  14. 48 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankRequestTaskSettings.java
  15. 123 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankServiceSettings.java
  16. 149 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankTaskSettings.java
  17. 128 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntity.java
  18. 308 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
  19. 124 53
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
  20. 65 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntityTests.java
  21. 159 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestTests.java
  22. 130 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModelTests.java
  23. 83 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankRequestTaskSettingsTests.java
  24. 123 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankServiceSettingsTests.java
  25. 230 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankTaskSettingsTests.java
  26. 113 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntityTests.java

+ 5 - 0
docs/changelog/129848.yaml

@@ -0,0 +1,5 @@
+pr: 129848
+summary: "[ML] Add Azure AI Rerank support to the Inference Plugin"
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -341,6 +341,7 @@ public class TransportVersions {
     public static final TransportVersion LOOKUP_JOIN_CCS = def(9_120_0_00);
     public static final TransportVersion NODE_USAGE_STATS_FOR_THREAD_POOLS_IN_CLUSTER_INFO = def(9_121_0_00);
     public static final TransportVersion ESQL_CATEGORIZE_OPTIONS = def(9_122_0_00);
+    public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED = def(9_123_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 1 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

@@ -111,6 +111,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
             containsInAnyOrder(
                 List.of(
                     "alibabacloud-ai-search",
+                    "azureaistudio",
                     "cohere",
                     "elasticsearch",
                     "googlevertexai",

+ 13 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -50,6 +50,8 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.Azure
 import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
 import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
@@ -306,6 +308,17 @@ public class InferenceNamedWriteablesProvider {
                 AzureAiStudioChatCompletionTaskSettings::new
             )
         );
+
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                AzureAiStudioRerankServiceSettings.NAME,
+                AzureAiStudioRerankServiceSettings::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
+        );
     }
 
     private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

+ 9 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio;
 public class AzureAiStudioConstants {
     public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
     public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
+    public static final String RERANK_URI_PATH = "/v1/rerank";
 
     // common service settings fields
     public static final String TARGET_FIELD = "target";
@@ -22,6 +23,10 @@ public class AzureAiStudioConstants {
     public static final String DIMENSIONS_FIELD = "dimensions";
     public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
 
+    // rerank task settings fields
+    public static final String DOCUMENTS_FIELD = "documents";
+    public static final String QUERY_FIELD = "query";
+
     // embeddings task settings fields
     public static final String USER_FIELD = "user";
 
@@ -35,5 +40,9 @@ public class AzureAiStudioConstants {
     public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
     public static final Double MAX_TEMPERATURE_TOP_P = 2.0;
 
+    // rerank task settings fields
+    public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
+    public static final String TOP_N_FIELD = "top_n";
+
     private AzureAiStudioConstants() {}
 }

+ 14 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java

@@ -22,6 +22,9 @@ public final class AzureAiStudioProviderCapabilities {
     // these providers have chat completion inference (all providers at the moment)
     public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());
 
+    // these providers have rerank inference
+    public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);
+
     // these providers allow token ("pay as you go") embeddings endpoints
     public static final List<AzureAiStudioProvider> tokenEmbeddingsProviders = List.of(
         AzureAiStudioProvider.OPENAI,
@@ -31,6 +34,9 @@ public final class AzureAiStudioProviderCapabilities {
     // these providers allow realtime embeddings endpoints (none at the moment)
     public static final List<AzureAiStudioProvider> realtimeEmbeddingsProviders = List.of();
 
+    // these providers allow realtime rerank endpoints (none at the moment)
+    public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();
+
     // these providers allow token ("pay as you go") chat completion endpoints
     public static final List<AzureAiStudioProvider> tokenChatCompletionProviders = List.of(
         AzureAiStudioProvider.OPENAI,
@@ -54,6 +60,9 @@ public final class AzureAiStudioProviderCapabilities {
             case TEXT_EMBEDDING -> {
                 return embeddingProviders.contains(provider);
             }
+            case RERANK -> {
+                return rerankProviders.contains(provider);
+            }
             default -> {
                 return false;
             }
@@ -76,6 +85,11 @@ public final class AzureAiStudioProviderCapabilities {
                     ? tokenEmbeddingsProviders.contains(provider)
                     : realtimeEmbeddingsProviders.contains(provider);
             }
+            case RERANK -> {
+                return (endpointType == AzureAiStudioEndpointType.TOKEN)
+                    ? rerankProviders.contains(provider)
+                    : realtimeRerankProviders.contains(provider);
+            }
             default -> {
                 return false;
             }

+ 68 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioRerankRequestManager.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
+import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity;
+import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
+
+import java.util.function.Supplier;
+
+public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager {
+    private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class);
+
+    private static final ResponseHandler HANDLER = createRerankHandler();
+
+    private final AzureAiStudioRerankModel model;
+
+    public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) {
+        super(threadPool, model);
+        this.model = model;
+    }
+
+    @Override
+    public void execute(
+        InferenceInputs inferenceInputs,
+        RequestSender requestSender,
+        Supplier<Boolean> hasRequestRerankFunction,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
+        AzureAiStudioRerankRequest request = new AzureAiStudioRerankRequest(
+            model,
+            rerankInput.getQuery(),
+            rerankInput.getChunks(),
+            rerankInput.getReturnDocuments(),
+            rerankInput.getTopN()
+        );
+
+        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestRerankFunction, listener));
+    }
+
+    private static ResponseHandler createRerankHandler() {
+        // This currently covers response handling for Azure AI Studio
+        return new AzureMistralOpenAiExternalResponseHandler(
+            "azure ai studio rerank",
+            new AzureAiStudioRerankResponseEntity(),
+            ErrorMessageResponseEntity::fromResponse,
+            true
+        );
+    }
+}

+ 12 - 22
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

@@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.completion.Azure
 import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
@@ -71,10 +72,10 @@ import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFie
 
 public class AzureAiStudioService extends SenderService {
 
-    static final String NAME = "azureaistudio";
+    public static final String NAME = "azureaistudio";
 
     private static final String SERVICE_NAME = "Azure AI Studio";
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
+    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK);
 
     private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
         InputType.INGEST,
@@ -270,8 +271,9 @@ public class AzureAiStudioService extends SenderService {
         ConfigurationParseContext context
     ) {
 
-        if (taskType == TaskType.TEXT_EMBEDDING) {
-            var embeddingsModel = new AzureAiStudioEmbeddingsModel(
+        AzureAiStudioModel model;
+        switch (taskType) {
+            case TEXT_EMBEDDING -> model = new AzureAiStudioEmbeddingsModel(
                 inferenceEntityId,
                 taskType,
                 NAME,
@@ -281,16 +283,7 @@ public class AzureAiStudioService extends SenderService {
                 secretSettings,
                 context
             );
-            checkProviderAndEndpointTypeForTask(
-                TaskType.TEXT_EMBEDDING,
-                embeddingsModel.getServiceSettings().provider(),
-                embeddingsModel.getServiceSettings().endpointType()
-            );
-            return embeddingsModel;
-        }
-
-        if (taskType == TaskType.COMPLETION) {
-            var completionModel = new AzureAiStudioChatCompletionModel(
+            case COMPLETION -> model = new AzureAiStudioChatCompletionModel(
                 inferenceEntityId,
                 taskType,
                 NAME,
@@ -299,15 +292,12 @@ public class AzureAiStudioService extends SenderService {
                 secretSettings,
                 context
             );
-            checkProviderAndEndpointTypeForTask(
-                TaskType.COMPLETION,
-                completionModel.getServiceSettings().provider(),
-                completionModel.getServiceSettings().endpointType()
-            );
-            return completionModel;
+            case RERANK -> model = new AzureAiStudioRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context);
+            default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         }
-
-        throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+        final var azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
+        checkProviderAndEndpointTypeForTask(taskType, azureAiStudioServiceSettings.provider(), azureAiStudioServiceSettings.endpointType());
+        return model;
     }
 
     private AzureAiStudioModel createModelFromPersistent(

+ 10 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionCreator.java

@@ -13,8 +13,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager;
 import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager;
 import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
 
 import java.util.Map;
 import java.util.Objects;
@@ -49,4 +51,12 @@ public class AzureAiStudioActionCreator implements AzureAiStudioActionVisitor {
         var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
         return new SenderExecutableAction(sender, requestManager, errorMessage);
     }
+
+    @Override
+    public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings) {
+        var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings);
+        var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool());
+        var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank");
+        return new SenderExecutableAction(sender, requestManager, errorMessage);
+    }
 }

+ 3 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionVisitor.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.azureaistudio.action;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
 
 import java.util.Map;
 
@@ -17,4 +18,6 @@ public interface AzureAiStudioActionVisitor {
     ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);
 
     ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);
+
+    ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings);
 }

+ 74 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequest.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.request;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
+
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+public class AzureAiStudioRerankRequest extends AzureAiStudioRequest {
+    private final String query;
+    private final List<String> input;
+    private final Boolean returnDocuments;
+    private final Integer topN;
+    private final AzureAiStudioRerankModel rerankModel;
+
+    public AzureAiStudioRerankRequest(
+        AzureAiStudioRerankModel model,
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN
+    ) {
+        super(model);
+        this.rerankModel = Objects.requireNonNull(model);
+        this.query = query;
+        this.input = Objects.requireNonNull(input);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
+    }
+
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(this.uri);
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8));
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+        setAuthHeader(httpPost, rerankModel);
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public Request truncate() {
+        // Not applicable for rerank, only used in text embedding requests
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        // Not applicable for rerank, only used in text embedding requests
+        return null;
+    }
+
+    private AzureAiStudioRerankRequestEntity createRequestEntity() {
+        return new AzureAiStudioRerankRequestEntity(query, input, returnDocuments, topN, rerankModel.getTaskSettings());
+    }
+}

+ 59 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntity.java

@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.request;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.QUERY_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
+
+public record AzureAiStudioRerankRequestEntity(
+    String query,
+    List<String> input,
+    @Nullable Boolean returnDocuments,
+    @Nullable Integer topN,
+    AzureAiStudioRerankTaskSettings taskSettings
+) implements ToXContentObject {
+
+    public AzureAiStudioRerankRequestEntity {
+        Objects.requireNonNull(query);
+        Objects.requireNonNull(input);
+        Objects.requireNonNull(taskSettings);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        builder.field(DOCUMENTS_FIELD, input);
+        builder.field(QUERY_FIELD, query);
+
+        if (returnDocuments != null) {
+            builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
+        } else if (taskSettings.returnDocuments() != null) {
+            builder.field(RETURN_DOCUMENTS_FIELD, taskSettings.returnDocuments());
+        }
+
+        if (topN != null) {
+            builder.field(TOP_N_FIELD, topN);
+        } else if (taskSettings.topN() != null) {
+            builder.field(TOP_N_FIELD, taskSettings.topN());
+        }
+        builder.endObject();
+        return builder;
+    }
+}

+ 95 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModel.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
+import org.elasticsearch.xpack.inference.services.azureaistudio.action.AzureAiStudioActionVisitor;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RERANK_URI_PATH;
+
+public class AzureAiStudioRerankModel extends AzureAiStudioModel {
+
+    public static AzureAiStudioRerankModel of(AzureAiStudioRerankModel model, Map<String, Object> taskSettings) {
+        if (taskSettings == null || taskSettings.isEmpty()) {
+            return model;
+        }
+
+        final var requestTaskSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(taskSettings);
+        final var taskSettingToUse = AzureAiStudioRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
+
+        return new AzureAiStudioRerankModel(model, taskSettingToUse);
+    }
+
+    public AzureAiStudioRerankModel(
+        String inferenceEntityId,
+        AzureAiStudioRerankServiceSettings serviceSettings,
+        AzureAiStudioRerankTaskSettings taskSettings,
+        DefaultSecretSettings secrets
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, TaskType.RERANK, AzureAiStudioService.NAME, serviceSettings, taskSettings),
+            new ModelSecrets(secrets)
+        );
+    }
+
+    public AzureAiStudioRerankModel(
+        String inferenceEntityId,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            AzureAiStudioRerankServiceSettings.fromMap(serviceSettings, context),
+            AzureAiStudioRerankTaskSettings.fromMap(taskSettings),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    public AzureAiStudioRerankModel(AzureAiStudioRerankModel model, AzureAiStudioRerankTaskSettings taskSettings) {
+        super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
+    }
+
+    @Override
+    public AzureAiStudioRerankServiceSettings getServiceSettings() {
+        return (AzureAiStudioRerankServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public AzureAiStudioRerankTaskSettings getTaskSettings() {
+        return (AzureAiStudioRerankTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return super.getSecretSettings();
+    }
+
+    @Override
+    protected URI getEndpointUri() throws URISyntaxException {
+        return new URI(this.target + RERANK_URI_PATH);
+    }
+
+    @Override
+    public ExecutableAction accept(AzureAiStudioActionVisitor creator, Map<String, Object> taskSettings) {
+        return creator.create(this, taskSettings);
+    }
+}

+ 48 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankRequestTaskSettings.java

@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
+
+public record AzureAiStudioRerankRequestTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
+
+    public static final AzureAiStudioRerankRequestTaskSettings EMPTY_SETTINGS = new AzureAiStudioRerankRequestTaskSettings(null, null);
+
+    /**
+     * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
+     * does not throw an error.
+     *
+     * @param map the settings received from a request
+     * @return a {@link AzureAiStudioRerankRequestTaskSettings}
+     */
+    public static AzureAiStudioRerankRequestTaskSettings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS;
+        }
+
+        final var validationException = new ValidationException();
+
+        final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
+        final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureAiStudioRerankRequestTaskSettings(returnDocuments, topN);
+    }
+}

+ 123 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankServiceSettings.java

@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+public class AzureAiStudioRerankServiceSettings extends AzureAiStudioServiceSettings {
+    public static final String NAME = "azure_ai_studio_rerank_service_settings";
+
+    public static AzureAiStudioRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        final var validationException = new ValidationException();
+
+        final var settings = rerankSettingsFromMap(map, validationException, context);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureAiStudioRerankServiceSettings(settings);
+    }
+
+    private static AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields rerankSettingsFromMap(
+        Map<String, Object> map,
+        ValidationException validationException,
+        ConfigurationParseContext context
+    ) {
+        final var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);
+        return new AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields(baseSettings);
+    }
+
+    private record AzureAiStudioRerankCommonFields(BaseAzureAiStudioCommonFields baseCommonFields) {}
+
+    public AzureAiStudioRerankServiceSettings(
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        super(target, provider, endpointType, rateLimitSettings);
+    }
+
+    public AzureAiStudioRerankServiceSettings(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    private AzureAiStudioRerankServiceSettings(AzureAiStudioRerankServiceSettings.AzureAiStudioRerankCommonFields fields) {
+        this(
+            fields.baseCommonFields.target(),
+            fields.baseCommonFields.provider(),
+            fields.baseCommonFields.endpointType(),
+            fields.baseCommonFields.rateLimitSettings()
+        );
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        super.addXContentFields(builder, params);
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        super.addExposedXContentFields(builder, params);
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AzureAiStudioRerankServiceSettings that = (AzureAiStudioRerankServiceSettings) o;
+
+        return Objects.equals(target, that.target)
+            && Objects.equals(provider, that.provider)
+            && Objects.equals(endpointType, that.endpointType)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(target, provider, endpointType, rateLimitSettings);
+    }
+}

+ 149 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankTaskSettings.java

@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
+
+/**
+ * Defines the rerank task settings for the AzureAiStudio service.
+ */
+public class AzureAiStudioRerankTaskSettings implements TaskSettings {
+    public static final String NAME = "azure_ai_studio_rerank_task_settings";
+
+    public static AzureAiStudioRerankTaskSettings fromMap(Map<String, Object> map) {
+        final var validationException = new ValidationException();
+
+        final var returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS_FIELD, validationException);
+        final var topN = extractOptionalPositiveInteger(map, TOP_N_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureAiStudioRerankTaskSettings(returnDocuments, topN);
+    }
+
+    /**
+     * Creates a new {@link AzureAiStudioRerankTaskSettings} object by overriding the values in originalSettings with the ones
+     * passed in via requestSettings if the fields are not null.
+     * @param originalSettings the original {@link AzureAiStudioRerankTaskSettings} from the inference entity configuration from storage
+     * @param requestSettings the {@link AzureAiStudioRerankTaskSettings} from the request
+     * @return a new {@link AzureAiStudioRerankTaskSettings}
+     */
+    public static AzureAiStudioRerankTaskSettings of(
+        AzureAiStudioRerankTaskSettings originalSettings,
+        AzureAiStudioRerankRequestTaskSettings requestSettings
+    ) {
+
+        final var returnDocuments = requestSettings.returnDocuments() == null
+            ? originalSettings.returnDocuments()
+            : requestSettings.returnDocuments();
+        final var topN = requestSettings.topN() == null ? originalSettings.topN() : requestSettings.topN();
+
+        return new AzureAiStudioRerankTaskSettings(returnDocuments, topN);
+    }
+
+    public AzureAiStudioRerankTaskSettings(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
+    }
+
+    public AzureAiStudioRerankTaskSettings(StreamInput in) throws IOException {
+        this.returnDocuments = in.readOptionalBoolean();
+        this.topN = in.readOptionalVInt();
+    }
+
+    private final Boolean returnDocuments;
+    private final Integer topN;
+
+    public Boolean returnDocuments() {
+        return returnDocuments;
+    }
+
+    public Integer topN() {
+        return topN;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO_RERANK_ADDED;
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return returnDocuments == null && topN == null;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalBoolean(returnDocuments);
+        out.writeOptionalVInt(topN);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        if (returnDocuments != null) {
+            builder.field(RETURN_DOCUMENTS_FIELD, returnDocuments);
+        }
+        if (topN != null) {
+            builder.field(TOP_N_FIELD, topN);
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String toString() {
+        return "AzureAiStudioRerankTaskSettings{" + ", returnDocuments=" + returnDocuments + ", topN=" + topN + '}';
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AzureAiStudioRerankTaskSettings that = (AzureAiStudioRerankTaskSettings) o;
+        return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topN, that.topN);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(returnDocuments, topN);
+    }
+
+    @Override
+    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
+        AzureAiStudioRerankRequestTaskSettings requestSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(newSettings));
+        return of(this, requestSettings);
+    }
+}

+ 128 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntity.java

@@ -0,0 +1,128 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.response;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+public class AzureAiStudioRerankResponseEntity extends BaseResponseEntity {
+    /**
+     * Parses the AzureAiStudio Search rerank json response.
+     * For a request like:
+     *
+     * <pre>
+     * <code>
+     * {
+     *     "model": "rerank-v3.5",
+     *     "query": "What is the capital of the United States?",
+     *     "top_n": 2,
+     *     "documents": ["Carson City is the capital city of the American state of Nevada.",
+     *                   "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."]
+     * }
+     * </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     * <code>
+     * {
+     *     "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
+     *     "results": [
+     *         {
+     *             "document": {
+     *                 "text": "Carson City is the capital city of the American state of Nevada."
+     *             },
+     *             "index": 0,
+     *             "relevance_score": 0.1728413
+     *         },
+     *         {
+     *             "document": {
+     *                 "text": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean."
+     *             },
+     *             "index": 1,
+     *             "relevance_score": 0.031005697
+     *         }
+     *     ],
+     *     "meta": {
+     *         "api_version": {
+     *             "version": "1"
+     *         },
+     *         "billed_units": {
+     *             "search_units": 1
+     *         }
+     *     }
+     * }
+     * </code>
+     * </pre>
+     */
+    @Override
+    protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+        final var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
+            return new RankedDocsResults(rerankResult.entries.stream().map(RerankResultEntry::toRankedDoc).toList());
+        }
+    }
+
+    record RerankResult(List<RerankResultEntry> entries) {
+        @SuppressWarnings("unchecked")
+        public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
+            RerankResult.class.getSimpleName(),
+            true,
+            args -> new RerankResult((List<RerankResultEntry>) args[0])
+        );
+        static {
+            PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
+        }
+    }
+
+    record RerankResultEntry(Float relevanceScore, Integer index, @Nullable ObjectParser document) {
+
+        public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
+            RerankResultEntry.class.getSimpleName(),
+            args -> new RerankResultEntry((Float) args[0], (Integer) args[1], (ObjectParser) args[2])
+        );
+        static {
+            PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
+            PARSER.declareInt(constructorArg(), new ParseField("index"));
+            PARSER.declareObject(optionalConstructorArg(), ObjectParser.PARSER::apply, new ParseField("document"));
+        }
+        public RankedDocsResults.RankedDoc toRankedDoc() {
+            return new RankedDocsResults.RankedDoc(index, relevanceScore, document == null ? null : document.text);
+        }
+    }
+
+    record ObjectParser(String text) {
+        public static final ConstructingObjectParser<ObjectParser, Void> PARSER = new ConstructingObjectParser<>(
+            ObjectParser.class.getSimpleName(),
+            args -> new AzureAiStudioRerankResponseEntity.ObjectParser((String) args[0])
+        );
+        static {
+            PARSER.declareString(optionalConstructorArg(), new ParseField("text"));
+        }
+    }
+}

+ 308 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

@@ -38,6 +38,7 @@ import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -54,6 +55,10 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.Matchers;
@@ -219,6 +224,33 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testParseRequestConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
+        try (var service = createService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
+
+                var rerankModel = (AzureAiStudioRerankModel) model;
+                assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
+                assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
+                assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+                assertThat(rerankModel.getSecretSettings().apiKey().toString(), is("secret"));
+                assertNull(rerankModel.getTaskSettings().returnDocuments());
+                assertNull(rerankModel.getTaskSettings().topN());
+            }, exception -> fail("Unexpected exception: " + exception));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.RERANK,
+                getRequestConfigMap(
+                    getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
+                    getRerankTaskSettingsMap(null, null),
+                    getSecretSettingsMap("secret")
+                ),
+                modelVerificationListener
+            );
+        }
+    }
+
     public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
         try (var service = createService()) {
             ActionListener<Model> modelVerificationListener = ActionListener.wrap(
@@ -441,6 +473,80 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
+            serviceSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var taskSettings = getRerankTaskSettingsMap(null, null);
+            taskSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(
+                getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
+                taskSettings,
+                getSecretSettingsMap("secret")
+            );
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var secretSettings = getSecretSettingsMap("secret");
+            secretSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(
+                getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
+                getRerankTaskSettingsMap(null, null),
+                secretSettings
+            );
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
+        }
+    }
+
     public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() throws IOException {
         try (var service = createService()) {
             var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "databricks", "token", null, null, null, null);
@@ -505,6 +611,45 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForRerank() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getRerankServiceSettingsMap("http://target.local", "databricks", "token");
+
+            var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(exception.getMessage(), is("The [rerank] task type for provider [databricks] is not available"));
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "realtime");
+
+            var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
+        }
+    }
+
     public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException {
         try (var service = createService()) {
             var config = getPersistedConfigMap(
@@ -603,6 +748,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
+        try (var service = createService()) {
+            var config = getPersistedConfigMap(
+                getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
+                getRerankTaskSettingsMap(true, 2),
+                getSecretSettingsMap("secret")
+            );
+
+            var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
+
+            assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
+
+            var chatCompletionModel = (AzureAiStudioRerankModel) model;
+            assertThat(chatCompletionModel.getServiceSettings().target(), is("http://target.local"));
+            assertThat(chatCompletionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
+            assertThat(chatCompletionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+            assertThat(chatCompletionModel.getTaskSettings().returnDocuments(), is(true));
+            assertThat(chatCompletionModel.getTaskSettings().topN(), is(2));
+        }
+    }
+
     public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException {
         try (var service = createService()) {
             ActionListener<Model> modelVerificationListener = ActionListener.wrap(
@@ -747,6 +913,48 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
+            serviceSettings.put("extra_key", "value");
+            var taskSettings = getRerankTaskSettingsMap(true, 2);
+            var secretSettings = getSecretSettingsMap("secret");
+            var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+
+            var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
+
+            assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
+        }
+    }
+
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
+            var taskSettings = getRerankTaskSettingsMap(true, 2);
+            taskSettings.put("extra_key", "value");
+            var secretSettings = getSecretSettingsMap("secret");
+            var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+
+            var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
+
+            assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
+        }
+    }
+
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
+            var taskSettings = getRerankTaskSettingsMap(true, 2);
+            var secretSettings = getSecretSettingsMap("secret");
+            secretSettings.put("extra_key", "value");
+            var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+
+            var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
+
+            assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
+        }
+    }
+
     public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
         try (var service = createService()) {
             var config = getPersistedConfigMap(
@@ -842,6 +1050,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws IOException {
+        try (var service = createService()) {
+            var config = getPersistedConfigMap(
+                getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
+                getRerankTaskSettingsMap(true, 2),
+                Map.of()
+            );
+
+            var model = service.parsePersistedConfig("id", TaskType.RERANK, config.config());
+
+            assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
+
+            var rerankModel = (AzureAiStudioRerankModel) model;
+            assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
+            assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
+            assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+            assertThat(rerankModel.getTaskSettings().returnDocuments(), is(true));
+            assertThat(rerankModel.getTaskSettings().topN(), is(2));
+        }
+    }
+
     public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
@@ -1184,6 +1413,47 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testInfer_WithRerankModel() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
+
+            var model = AzureAiStudioRerankModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.COHERE,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey"
+            );
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                model,
+                "query",
+                false,
+                2,
+                List.of("abc"),
+                false,
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(result, CoreMatchers.instanceOf(RankedDocsResults.class));
+
+            var rankedDocsResults = (RankedDocsResults) result;
+            var rankedDocs = rankedDocsResults.getRankedDocs();
+            assertThat(rankedDocs.size(), is(2));
+            assertThat(rankedDocs.get(0).relevanceScore(), is(0.1111111F));
+            assertThat(rankedDocs.get(0).index(), is(0));
+            assertThat(rankedDocs.get(1).relevanceScore(), is(0.2222222F));
+            assertThat(rankedDocs.get(1).index(), is(1));
+        }
+    }
+
     public void testInfer_UnauthorisedResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
@@ -1320,7 +1590,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                     {
                         "service": "azureaistudio",
                         "name": "Azure AI Studio",
-                        "task_types": ["text_embedding", "completion"],
+                        "task_types": ["text_embedding", "rerank", "completion"],
                         "configurations": {
                             "dimensions": {
                                 "description": "The number of dimensions the resulting embeddings should have. For more information refer to https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-embeddings.",
@@ -1338,7 +1608,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                                 "sensitive": false,
                                 "updatable": false,
                                 "type": "str",
-                                "supported_task_types": ["text_embedding", "completion"]
+                                "supported_task_types": ["text_embedding", "rerank", "completion"]
                             },
                             "provider": {
                                 "description": "The model provider for your deployment.",
@@ -1347,7 +1617,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                                 "sensitive": false,
                                 "updatable": false,
                                 "type": "str",
-                                "supported_task_types": ["text_embedding", "completion"]
+                                "supported_task_types": ["text_embedding", "rerank", "completion"]
                             },
                             "api_key": {
                                 "description": "API Key for the provider you're connecting to.",
@@ -1356,7 +1626,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                                 "sensitive": true,
                                 "updatable": true,
                                 "type": "str",
-                                "supported_task_types": ["text_embedding", "completion"]
+                                "supported_task_types": ["text_embedding", "rerank", "completion"]
                             },
                             "rate_limit.requests_per_minute": {
                                 "description": "Minimize the number of rate limit errors.",
@@ -1365,7 +1635,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                                 "sensitive": false,
                                 "updatable": false,
                                 "type": "int",
-                                "supported_task_types": ["text_embedding", "completion"]
+                                "supported_task_types": ["text_embedding", "rerank", "completion"]
                             },
                             "target": {
                                 "description": "The target URL of your Azure AI Studio model deployment.",
@@ -1374,7 +1644,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                                 "sensitive": false,
                                 "updatable": false,
                                 "type": "str",
-                                "supported_task_types": ["text_embedding", "completion"]
+                                "supported_task_types": ["text_embedding", "rerank", "completion"]
                             }
                         }
                     }
@@ -1462,6 +1732,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         return AzureAiStudioChatCompletionServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
     }
 
+    private static HashMap<String, Object> getRerankServiceSettingsMap(String target, String provider, String endpointType) {
+        return AzureAiStudioRerankServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
+    }
+
     public static Map<String, Object> getChatCompletionTaskSettingsMap(
         @Nullable Double temperature,
         @Nullable Double topP,
@@ -1471,6 +1745,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         return AzureAiStudioChatCompletionTaskSettingsTests.getTaskSettingsMap(temperature, topP, doSample, maxNewTokens);
     }
 
+    public static Map<String, Object> getRerankTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
+        return AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(returnDocuments, topN);
+    }
+
     private static Map<String, Object> getSecretSettingsMap(String apiKey) {
         return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
     }
@@ -1520,4 +1798,28 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             }
         }
         """;
+
+    private static final String testRerankTokenResponseJson = """
+        {
+            "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
+            "results": [
+                {
+                    "index": 0,
+                    "relevance_score": 0.1111111
+                },
+                {
+                    "index": 1,
+                    "relevance_score": 0.2222222
+                }
+            ],
+            "meta": {
+                "api_version": {
+                    "version": "1"
+                },
+                "billed_units": {
+                    "search_units": 1
+                }
+            }
+        }
+        """;
 }

+ 124 - 53
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java

@@ -20,6 +20,7 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
 import org.elasticsearch.xpack.inference.InputTypeTests;
 import org.elasticsearch.xpack.inference.common.TruncatorTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -27,6 +28,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInpu
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
@@ -34,6 +36,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEnd
 import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
 import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
 import org.junit.After;
 import org.junit.Before;
 
@@ -78,31 +81,20 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
     }
 
     public void testEmbeddingsRequestAction() throws IOException {
-        var senderFactory = new HttpRequestSender.Factory(
+        final var senderFactory = new HttpRequestSender.Factory(
             ServiceComponentsTests.createWithEmptySettings(threadPool),
             clientManager,
             mockClusterServiceEmpty()
         );
 
-        var timeoutSettings = buildSettingsWithRetryFields(
-            TimeValue.timeValueMillis(1),
-            TimeValue.timeValueMinutes(1),
-            TimeValue.timeValueSeconds(0)
-        );
-
-        var serviceComponents = new ServiceComponents(
-            threadPool,
-            mock(ThrottlerManager.class),
-            timeoutSettings,
-            TruncatorTests.createTruncator()
-        );
+        final var serviceComponents = getServiceComponents();
 
         try (var sender = createSender(senderFactory)) {
             sender.start();
 
             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
 
-            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+            final var model = AzureAiStudioEmbeddingsModelTests.createModel(
                 "id",
                 "http://will-be-replaced.local",
                 AzureAiStudioProvider.OPENAI,
@@ -111,21 +103,18 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             );
             model.setURI(getUrl(webServer));
 
-            var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
-            var action = creator.create(model, Map.of());
-            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            var inputType = InputTypeTests.randomSearchAndIngestWithNull();
+            final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
+            final var action = creator.create(model, Map.of());
+            final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            final var inputType = InputTypeTests.randomSearchAndIngestWithNull();
             action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
-            var result = listener.actionGet(TIMEOUT);
+            final var result = listener.actionGet(TIMEOUT);
 
             assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F }))));
-            assertThat(webServer.requests(), hasSize(1));
-            assertNull(webServer.requests().get(0).getUri().getQuery());
-            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
-            assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
+            assertWebServerRequest(API_KEY_HEADER, "apikey");
 
-            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            final var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             assertThat(requestMap.size(), is(InputType.isSpecified(inputType) ? 2 : 1));
             assertThat(requestMap.get("input"), is(List.of("abc")));
             if (InputType.isSpecified(inputType)) {
@@ -136,27 +125,15 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
     }
 
     public void testChatCompletionRequestAction() throws IOException {
-        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-
-        var timeoutSettings = buildSettingsWithRetryFields(
-            TimeValue.timeValueMillis(1),
-            TimeValue.timeValueMinutes(1),
-            TimeValue.timeValueSeconds(0)
-        );
-
-        var serviceComponents = new ServiceComponents(
-            threadPool,
-            mock(ThrottlerManager.class),
-            timeoutSettings,
-            TruncatorTests.createTruncator()
-        );
+        final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        final var serviceComponents = getServiceComponents();
 
         try (var sender = createSender(senderFactory)) {
             sender.start();
 
             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));
-            var webserverUrl = getUrl(webServer);
-            var model = AzureAiStudioChatCompletionModelTests.createModel(
+            final var webserverUrl = getUrl(webServer);
+            final var model = AzureAiStudioChatCompletionModelTests.createModel(
                 "id",
                 "http://will-be-replaced.local",
                 AzureAiStudioProvider.COHERE,
@@ -165,30 +142,101 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             );
             model.setURI(webserverUrl);
 
-            var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
-            var action = creator.create(model, Map.of());
+            final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
+            final var action = creator.create(model, Map.of());
 
-            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
 
-            var result = listener.actionGet(TIMEOUT);
+            final var result = listener.actionGet(TIMEOUT);
 
             assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string"))));
-            assertThat(webServer.requests(), hasSize(1));
-
-            MockRequest request = webServer.requests().get(0);
 
-            assertNull(request.getUri().getQuery());
-            assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
-            assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("apikey"));
+            assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey");
 
-            var requestMap = entityAsMap(request.getBody());
+            final MockRequest request = webServer.requests().get(0);
+            final var requestMap = entityAsMap(request.getBody());
             assertThat(requestMap.size(), is(1));
             assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
         }
     }
 
-    private static String testEmbeddingsTokenResponseJson = """
+    public void testRerankRequestAction() throws IOException {
+        final var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        final var serviceComponents = getServiceComponents();
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
+            final var webserverUrl = getUrl(webServer);
+            final var model = AzureAiStudioRerankModelTests.createModel(
+                "id",
+                "http://will-be-replaced.local",
+                AzureAiStudioProvider.COHERE,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey"
+            );
+            model.setURI(webserverUrl);
+
+            final var topN = 2;
+            final var returnDocuments = false;
+            final var query = "query";
+            final var documents = List.of("document 1", "document 2", "document 3");
+
+            final var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
+            final var action = creator.create(model, Map.of());
+
+            final PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(
+                new QueryAndDocsInputs(query, documents, returnDocuments, topN, false),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            final var result = listener.actionGet(TIMEOUT);
+
+            assertThat(
+                result.asMap(),
+                equalTo(
+                    RankedDocsResultsTests.buildExpectationRerank(
+                        List.of(
+                            new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.1111111f)),
+                            new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.2222222f))
+                        )
+                    )
+                )
+            );
+
+            assertWebServerRequest(HttpHeaders.AUTHORIZATION, "apikey");
+
+            final var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+
+            assertThat(requestMap.size(), is(4));
+            assertThat(requestMap.get("documents"), is(documents));
+            assertThat(requestMap.get("query"), is(query));
+            assertThat(requestMap.get("top_n"), is(topN));
+            assertThat(requestMap.get("return_documents"), is(returnDocuments));
+        }
+    }
+
+    private void assertWebServerRequest(String authorization, String authorizationHeaderValue) {
+        assertThat(webServer.requests(), hasSize(1));
+        assertNull(webServer.requests().get(0).getUri().getQuery());
+        assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+        assertThat(webServer.requests().get(0).getHeader(authorization), equalTo(authorizationHeaderValue));
+    }
+
+    private ServiceComponents getServiceComponents() {
+        final var timeoutSettings = buildSettingsWithRetryFields(
+            TimeValue.timeValueMillis(1),
+            TimeValue.timeValueMinutes(1),
+            TimeValue.timeValueSeconds(0)
+        );
+        return new ServiceComponents(threadPool, mock(ThrottlerManager.class), timeoutSettings, TruncatorTests.createTruncator());
+    }
+
+    private final String testEmbeddingsTokenResponseJson = """
         {
           "object": "list",
           "data": [
@@ -209,7 +257,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
         }
         """;
 
-    private static String testCompletionTokenResponseJson = """
+    private final String testCompletionTokenResponseJson = """
         {
             "choices": [
                 {
@@ -233,4 +281,27 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
             }
         }""";
 
+    private final String testRerankTokenResponseJson = """
+        {
+            "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
+            "results": [
+                {
+                    "index": 0,
+                    "relevance_score": 0.1111111
+                },
+                {
+                    "index": 1,
+                    "relevance_score": 0.2222222
+                }
+            ],
+            "meta": {
+                "api_version": {
+                    "version": "1"
+                },
+                "billed_units": {
+                    "search_units": 1
+                }
+            }
+        }
+        """;
 }

+ 65 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestEntityTests.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.request;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
+
+public class AzureAiStudioRerankRequestEntityTests extends ESTestCase {
+    private static final String INPUT = "texts";
+    private static final String QUERY = "query";
+    private static final Boolean RETURN_DOCUMENTS = false;
+    private static final Integer TOP_N = 8;
+
+    public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
+        final var entity = new AzureAiStudioRerankRequestEntity(
+            QUERY,
+            List.of(INPUT),
+            Boolean.TRUE,
+            TOP_N,
+            new AzureAiStudioRerankTaskSettings(RETURN_DOCUMENTS, TOP_N)
+        );
+
+        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+        final String xContentResult = Strings.toString(builder);
+        final String expected = """
+            {"documents":["texts"],
+            "query":"query",
+            "return_documents":true,
+            "top_n":8}""";
+        assertEquals(stripWhitespace(expected), xContentResult);
+    }
+
+    public void testXContent_WritesMinimalFields() throws IOException {
+        final var entity = new AzureAiStudioRerankRequestEntity(
+            QUERY,
+            List.of(INPUT),
+            null,
+            null,
+            new AzureAiStudioRerankTaskSettings(null, null)
+        );
+
+        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+        final String xContentResult = Strings.toString(builder);
+        final String expected = """
+            {"documents":["texts"],"query":"query"}""";
+        assertEquals(stripWhitespace(expected), xContentResult);
+    }
+}

+ 159 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/request/AzureAiStudioRerankRequestTests.java

@@ -0,0 +1,159 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.request;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class AzureAiStudioRerankRequestTests extends ESTestCase {
+    private static final String TARGET_URI = "http://testtarget.local";
+    private static final String INPUT = "documents";
+    private static final String QUERY = "query";
+    private static final Integer TOP_N = 2;
+
+    public void testCreateRequest_WithCohereProviderTokenEndpoint_NoParams() throws IOException {
+        final var input = randomAlphaOfLength(3);
+        final var query = randomAlphaOfLength(3);
+        final var apikey = randomAlphaOfLength(3);
+        final var request = createRequest(TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey, query, input);
+        final var httpPost = getHttpPost(request, apikey);
+        final var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get(QUERY), is(query));
+        assertThat(requestMap.get(INPUT), is(List.of(input)));
+    }
+
+    public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopNParam() throws IOException {
+        final var input = randomAlphaOfLength(3);
+        final var query = randomAlphaOfLength(3);
+        final var apikey = randomAlphaOfLength(3);
+        final var request = createRequest(
+            TARGET_URI,
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            apikey,
+            null,
+            TOP_N,
+            query,
+            input
+        );
+        final var httpPost = getHttpPost(request, apikey);
+        final var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(3));
+        assertThat(requestMap.get(QUERY), is(query));
+        assertThat(requestMap.get(INPUT), is(List.of(input)));
+        assertThat(requestMap.get(TOP_N_FIELD), is(TOP_N));
+    }
+
+    public void testCreateRequest_WithCohereProviderTokenEndpoint_WithReturnDocumentsParam() throws IOException {
+        final var input = randomAlphaOfLength(3);
+        final var query = randomAlphaOfLength(3);
+        final var apikey = randomAlphaOfLength(3);
+        final var request = createRequest(
+            TARGET_URI,
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            apikey,
+            true,
+            null,
+            query,
+            input
+        );
+        final var httpPost = getHttpPost(request, apikey);
+        final var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(3));
+        assertThat(requestMap.get(QUERY), is(query));
+        assertThat(requestMap.get(INPUT), is(List.of(input)));
+        assertThat(requestMap.get(RETURN_DOCUMENTS_FIELD), is(true));
+    }
+
+    private HttpPost getHttpPost(AzureAiStudioRerankRequest request, String apikey) {
+        final var httpRequest = request.createHttpRequest();
+
+        final var httpPost = validateRequestUrlAndContentType(httpRequest, TARGET_URI + "/v1/rerank");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, apikey);
+        return httpPost;
+    }
+
+    private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) {
+        assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
+        final var httpPost = (HttpPost) request.httpRequestBase();
+        assertThat(httpPost.getURI().toString(), is(expectedUrl));
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        return httpPost;
+    }
+
+    private void validateRequestApiKey(
+        HttpPost httpPost,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey
+    ) {
+        if (endpointType == AzureAiStudioEndpointType.TOKEN) {
+            if (provider == AzureAiStudioProvider.OPENAI) {
+                assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
+            } else {
+                assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(apiKey));
+            }
+        } else {
+            assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey));
+        }
+    }
+
+    public static AzureAiStudioRerankRequest createRequest(
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey,
+        String query,
+        String input
+    ) {
+        return createRequest(target, provider, endpointType, apiKey, null, null, query, input);
+    }
+
+    public static AzureAiStudioRerankRequest createRequest(
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        String query,
+        String input
+    ) {
+        final var model = AzureAiStudioRerankModelTests.createModel(
+            "id",
+            target,
+            provider,
+            endpointType,
+            apiKey,
+            returnDocuments,
+            topN,
+            null
+        );
+        return new AzureAiStudioRerankRequest(model, query, List.of(input), returnDocuments, topN);
+    }
+}

+ 130 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankModelTests.java

@@ -0,0 +1,130 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.net.URISyntaxException;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
+
+public class AzureAiStudioRerankModelTests extends ESTestCase {
+    private static final String MODEL_ID = "id";
+    private static final String TARGET_URI = "http://testtarget.local";
+    private static final String API_KEY = "apikey";
+    private static final Integer TOP_N = 1;
+    private static final Integer TOP_N_OVERRIDE = 2;
+
+    public void testOverrideWith_OverridesWithoutValues() {
+        final var model = createModel(
+            MODEL_ID,
+            TARGET_URI,
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            API_KEY,
+            true,
+            TOP_N,
+            null
+        );
+        final var requestTaskSettingsMap = getTaskSettingsMap(null, null);
+        final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettingsMap);
+
+        assertThat(overriddenModel, sameInstance(overriddenModel));
+    }
+
+    public void testOverrideWith_returnDocuments() {
+        final var model = createModel(
+            MODEL_ID,
+            TARGET_URI,
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            API_KEY,
+            true,
+            null,
+            null
+        );
+        final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(false, null);
+        final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings);
+
+        assertThat(
+            overriddenModel,
+            is(createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY, false, null, null))
+        );
+    }
+
+    public void testOverrideWith_topN() {
+        final var model = createModel(
+            MODEL_ID,
+            TARGET_URI,
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            API_KEY,
+            null,
+            TOP_N,
+            null
+        );
+        final var requestTaskSettings = AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(null, TOP_N_OVERRIDE);
+        final var overriddenModel = AzureAiStudioRerankModel.of(model, requestTaskSettings);
+        assertThat(
+            overriddenModel,
+            is(
+                createModel(
+                    MODEL_ID,
+                    TARGET_URI,
+                    AzureAiStudioProvider.COHERE,
+                    AzureAiStudioEndpointType.TOKEN,
+                    API_KEY,
+                    null,
+                    TOP_N_OVERRIDE,
+                    null
+                )
+            )
+        );
+    }
+
+    public void testSetsProperUrlForCohereTokenModel() throws URISyntaxException {
+        final var model = createModel(MODEL_ID, TARGET_URI, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, API_KEY);
+        assertThat(model.getEndpointUri().toString(), is(TARGET_URI + "/v1/rerank"));
+    }
+
+    public static AzureAiStudioRerankModel createModel(
+        String id,
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey
+    ) {
+        return createModel(id, target, provider, endpointType, apiKey, null, null, null);
+    }
+
+    public static AzureAiStudioRerankModel createModel(
+        String id,
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        return new AzureAiStudioRerankModel(
+            id,
+            new AzureAiStudioRerankServiceSettings(target, provider, endpointType, rateLimitSettings),
+            new AzureAiStudioRerankTaskSettings(returnDocuments, topN),
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+}

+ 83 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankRequestTaskSettingsTests.java

@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.test.ESTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureAiStudioRerankRequestTaskSettingsTests extends ESTestCase {
+    private static final String INVALID_FIELD_TYPE_STRING = "invalid";
+    private static final boolean RETURN_DOCUMENTS = true;
+    private static final int TOP_N = 2;
+
+    public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() {
+        assertThat(
+            AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of())),
+            is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS)
+        );
+    }
+
+    public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
+        assertThat(
+            AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))),
+            is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS)
+        );
+    }
+
+    public void testFromMap_ReturnsReturnDocuments() {
+        assertThat(
+            AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(RETURN_DOCUMENTS_FIELD, RETURN_DOCUMENTS))),
+            is(new AzureAiStudioRerankRequestTaskSettings(RETURN_DOCUMENTS, null))
+        );
+    }
+
+    public void testFromMap_ReturnsTopN() {
+        assertThat(
+            AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_N_FIELD, TOP_N))),
+            is(new AzureAiStudioRerankRequestTaskSettings(null, TOP_N))
+        );
+    }
+
+    public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
+        assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
+    }
+
+    public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
+        assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
+    }
+
+    private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
+        final var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "field ["
+                        + field
+                        + "] is not of the expected type. The value ["
+                        + INVALID_FIELD_TYPE_STRING
+                        + "] cannot be converted to a "
+                )
+            )
+        );
+    }
+}

+ 123 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankServiceSettingsTests.java

@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+import org.hamcrest.CoreMatchers;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType.TOKEN;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider.COHERE;
+import static org.hamcrest.Matchers.is;
+
+public class AzureAiStudioRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankServiceSettings> {
+    private static final String TARGET_URI = "http://testtarget.local";
+
+    public void testFromMap_Request_CreatesSettingsCorrectly() {
+        final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(
+            createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
+    }
+
+    public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() {
+        final var settingsMap = createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name());
+        settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3)));
+
+        final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST);
+
+        assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3))));
+    }
+
+    public void testFromMap_Persistent_CreatesSettingsCorrectly() {
+        final var serviceSettings = AzureAiStudioRerankServiceSettings.fromMap(
+            createRequestSettingsMap(TARGET_URI, COHERE.name(), TOKEN.name()),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(serviceSettings, is(new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, null)));
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));
+        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        settings.toXContent(builder, null);
+        final String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"target":"http://testtarget.local","provider":"cohere","endpoint_type":"token",""" + """
+            "rate_limit":{"requests_per_minute":3}}"""));
+    }
+
+    public void testToFilteredXContent_WritesAllValues() throws IOException {
+        final var settings = new AzureAiStudioRerankServiceSettings(TARGET_URI, COHERE, TOKEN, new RateLimitSettings(3));
+        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        final var filteredXContent = settings.getFilteredXContentObject();
+        filteredXContent.toXContent(builder, null);
+        final String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"target":"http://testtarget.local","provider":"cohere","endpoint_type":"token",""" + """
+            "rate_limit":{"requests_per_minute":3}}"""));
+    }
+
+    public static HashMap<String, Object> createRequestSettingsMap(String target, String provider, String endpointType) {
+        return new HashMap<>(Map.of(TARGET_FIELD, target, PROVIDER_FIELD, provider, ENDPOINT_TYPE_FIELD, endpointType));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureAiStudioRerankServiceSettings> instanceReader() {
+        return AzureAiStudioRerankServiceSettings::new;
+    }
+
+    @Override
+    protected AzureAiStudioRerankServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureAiStudioRerankServiceSettings mutateInstance(AzureAiStudioRerankServiceSettings instance) throws IOException {
+        return randomValueOtherThan(instance, AzureAiStudioRerankServiceSettingsTests::createRandom);
+    }
+
+    @Override
+    protected AzureAiStudioRerankServiceSettings mutateInstanceForVersion(
+        AzureAiStudioRerankServiceSettings instance,
+        TransportVersion version
+    ) {
+        return instance;
+    }
+
+    private static AzureAiStudioRerankServiceSettings createRandom() {
+        return new AzureAiStudioRerankServiceSettings(
+            randomAlphaOfLength(10),
+            randomFrom(AzureAiStudioProvider.values()),
+            randomFrom(AzureAiStudioEndpointType.values()),
+            RateLimitSettingsTests.createRandom()
+        );
+    }
+}

+ 230 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/rerank/AzureAiStudioRerankTaskSettingsTests.java

@@ -0,0 +1,230 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.RETURN_DOCUMENTS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_N_FIELD;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureAiStudioRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AzureAiStudioRerankTaskSettings> {
+    private static final String INVALID_FIELD_TYPE_STRING = "invalid";
+
+    public void testIsEmpty() {
+        final var randomSettings = createRandom();
+        final var stringRep = Strings.toString(randomSettings);
+        assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
+    }
+
+    public void testUpdatedTaskSettings_WithAllValues() {
+        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
+        AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
+        assertUpdateSettings(newSettings, initialSettings);
+    }
+
+    public void testUpdatedTaskSettings_WithReturnDocumentsValue() {
+        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
+        AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
+        assertUpdateSettings(newSettings, initialSettings);
+    }
+
+    public void testUpdatedTaskSettings_WithTopNValue() {
+        final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
+        AzureAiStudioRerankTaskSettings newSettings = createRandom(initialSettings);
+        assertUpdateSettings(newSettings, initialSettings);
+    }
+
+    public void testUpdatedTaskSettings_WithNoValues() {
+        AzureAiStudioRerankTaskSettings initialSettings = createRandom();
+        final AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(null, null);
+        assertUpdateSettings(newSettings, initialSettings);
+    }
+
+    private void assertUpdateSettings(AzureAiStudioRerankTaskSettings newSettings, AzureAiStudioRerankTaskSettings initialSettings) {
+        final var settingsMap = new HashMap<String, Object>();
+        if (newSettings.returnDocuments() != null) settingsMap.put(RETURN_DOCUMENTS_FIELD, newSettings.returnDocuments());
+        if (newSettings.topN() != null) settingsMap.put(TOP_N_FIELD, newSettings.topN());
+
+        final AzureAiStudioRerankTaskSettings updatedSettings = (AzureAiStudioRerankTaskSettings) initialSettings.updatedTaskSettings(
+            Collections.unmodifiableMap(settingsMap)
+        );
+        assertEquals(
+            newSettings.returnDocuments() == null ? initialSettings.returnDocuments() : newSettings.returnDocuments(),
+            updatedSettings.returnDocuments()
+        );
+        assertEquals(newSettings.topN() == null ? initialSettings.topN() : newSettings.topN(), updatedSettings.topN());
+    }
+
+    public void testFromMap_AllValues() {
+        assertEquals(new AzureAiStudioRerankTaskSettings(true, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2)));
+    }
+
+    public void testFromMap_ReturnDocuments() {
+        assertEquals(
+            new AzureAiStudioRerankTaskSettings(true, null),
+            AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null))
+        );
+    }
+
+    public void testFromMap_TopN() {
+        assertEquals(new AzureAiStudioRerankTaskSettings(null, 2), AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2)));
+    }
+
+    public void testFromMap_ReturnDocumentsIsInvalidValue_ThrowsValidationException() {
+        getTaskSettingsMap(true, 2).put(RETURN_DOCUMENTS_FIELD, INVALID_FIELD_TYPE_STRING);
+        assertThrowsValidationExceptionIfStringValueProvidedFor(RETURN_DOCUMENTS_FIELD);
+    }
+
+    public void testFromMap_TopNIsInvalidValue_ThrowsValidationException() {
+        getTaskSettingsMap(true, 2).put(TOP_N_FIELD, INVALID_FIELD_TYPE_STRING);
+        assertThrowsValidationExceptionIfStringValueProvidedFor(TOP_N_FIELD);
+    }
+
+    public void testFromMap_WithNoValues_DoesNotThrowException() {
+        final var taskMap = AzureAiStudioRerankTaskSettings.fromMap(new HashMap<>(Map.of()));
+        assertNull(taskMap.returnDocuments());
+        assertNull(taskMap.topN());
+    }
+
+    public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
+        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, 2));
+        final var overrideSettings = AzureAiStudioRerankTaskSettings.of(settings, AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS);
+        MatcherAssert.assertThat(overrideSettings, is(settings));
+    }
+
+    public void testOverrideWith_UsesReturnDocumentsOverride() {
+        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(true, null));
+        final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(false, null));
+        final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
+        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(false, null)));
+    }
+
+    public void testOverrideWith_UsesTopNOverride() {
+        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(null, 2));
+        final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(null, 1));
+        final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
+        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(null, 1)));
+    }
+
+    public void testOverrideWith_UsesAllParametersOverride() {
+        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(false, 2));
+        final var overrideSettings = AzureAiStudioRerankRequestTaskSettings.fromMap(getTaskSettingsMap(true, 1));
+        final var overriddenTaskSettings = AzureAiStudioRerankTaskSettings.of(settings, overrideSettings);
+        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureAiStudioRerankTaskSettings(true, 1)));
+    }
+
+    public void testToXContent_WithoutParameters() throws IOException {
+        assertThat(getXContentResult(null, null), is("{}"));
+    }
+
+    public void testToXContent_WithReturnDocumentsParameter() throws IOException {
+        assertThat(getXContentResult(true, null), is("""
+            {"return_documents":true}"""));
+    }
+
+    public void testToXContent_WithTopNParameter() throws IOException {
+        assertThat(getXContentResult(null, 2), is("""
+            {"top_n":2}"""));
+    }
+
+    public void testToXContent_WithParameters() throws IOException {
+        assertThat(getXContentResult(true, 2), is("""
+            {"return_documents":true,"top_n":2}"""));
+    }
+
+    private String getXContentResult(Boolean returnDocuments, Integer topN) throws IOException {
+        final var settings = AzureAiStudioRerankTaskSettings.fromMap(getTaskSettingsMap(returnDocuments, topN));
+        final XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        settings.toXContent(builder, null);
+        return Strings.toString(builder);
+    }
+
+    public static Map<String, Object> getTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
+        final var map = new HashMap<String, Object>();
+
+        if (returnDocuments != null) {
+            map.put(RETURN_DOCUMENTS_FIELD, returnDocuments);
+        }
+
+        if (topN != null) {
+            map.put(TOP_N_FIELD, topN);
+        }
+
+        return map;
+    }
+
+    @Override
+    protected Writeable.Reader<AzureAiStudioRerankTaskSettings> instanceReader() {
+        return AzureAiStudioRerankTaskSettings::new;
+    }
+
+    @Override
+    protected AzureAiStudioRerankTaskSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureAiStudioRerankTaskSettings mutateInstance(AzureAiStudioRerankTaskSettings instance) throws IOException {
+        return randomValueOtherThan(instance, AzureAiStudioRerankTaskSettingsTests::createRandom);
+    }
+
+    @Override
+    protected AzureAiStudioRerankTaskSettings mutateInstanceForVersion(AzureAiStudioRerankTaskSettings instance, TransportVersion version) {
+        return instance;
+    }
+
+    private static AzureAiStudioRerankTaskSettings createRandom() {
+        return new AzureAiStudioRerankTaskSettings(
+            randomFrom(new Boolean[] { null, randomBoolean() }),
+            randomFrom(new Integer[] { null, randomNonNegativeInt() })
+        );
+    }
+
+    private static AzureAiStudioRerankTaskSettings createRandom(AzureAiStudioRerankTaskSettings settings) {
+        return new AzureAiStudioRerankTaskSettings(
+            randomValueOtherThan(settings.returnDocuments(), () -> randomFrom(new Boolean[] { null, randomBoolean() })),
+            randomValueOtherThan(settings.topN(), () -> randomFrom(new Integer[] { null, randomNonNegativeInt() }))
+        );
+    }
+
+    private void assertThrowsValidationExceptionIfStringValueProvidedFor(String field) {
+        final var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(field, INVALID_FIELD_TYPE_STRING)))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "field ["
+                        + field
+                        + "] is not of the expected type. The value ["
+                        + INVALID_FIELD_TYPE_STRING
+                        + "] cannot be converted to a "
+                )
+            )
+        );
+    }
+}

+ 113 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/response/AzureAiStudioRerankResponseEntityTests.java

@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class AzureAiStudioRerankResponseEntityTests extends ESTestCase {
+    public void testResponse_WithDocuments() throws IOException {
+        final String responseJson = getResponseJsonWithDocuments();
+
+        final var parsedResults = getParsedResults(responseJson);
+        final var expectedResults = List.of(
+            new RankedDocsResults.RankedDoc(0, 0.1111111F, "test text one"),
+            new RankedDocsResults.RankedDoc(1, 0.2222222F, "test text two")
+        );
+
+        assertThat(parsedResults.getRankedDocs(), is(expectedResults));
+    }
+
+    public void testResponse_NoDocuments() throws IOException {
+        final String responseJson = getResponseJsonNoDocuments();
+
+        final var parsedResults = getParsedResults(responseJson);
+        final var expectedResults = List.of(
+            new RankedDocsResults.RankedDoc(0, 0.1111111F, null),
+            new RankedDocsResults.RankedDoc(1, 0.2222222F, null)
+        );
+
+        assertThat(parsedResults.getRankedDocs(), is(expectedResults));
+    }
+
+    private RankedDocsResults getParsedResults(String responseJson) throws IOException {
+        final var entity = new AzureAiStudioRerankResponseEntity();
+        return (RankedDocsResults) entity.apply(
+            mock(Request.class),
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+    }
+
+    private String getResponseJsonWithDocuments() {
+        return """
+            {
+                "id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
+                "results": [
+                    {
+                        "document": {
+                            "text": "test text one"
+                        },
+                        "index": 0,
+                        "relevance_score": 0.1111111
+                    },
+                    {
+                        "document": {
+                            "text": "test text two"
+                        },
+                        "index": 1,
+                        "relevance_score": 0.2222222
+                    }
+                ],
+                "meta": {
+                    "api_version": {
+                        "version": "1"
+                    },
+                    "billed_units": {
+                        "search_units": 1
+                    }
+                }
+            }
+            """;
+    }
+
+    private String getResponseJsonNoDocuments() {
+        return """
+            {
+                "id": "222e59de-c712-40cb-ae87-ecd402d0d2f1",
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.1111111
+                    },
+                    {
+                        "index": 1,
+                        "relevance_score": 0.2222222
+                    }
+                ],
+                "meta": {
+                    "api_version": {
+                        "version": "1"
+                    },
+                    "billed_units": {
+                        "search_units": 1
+                    }
+                }
+            }
+            """;
+    }
+}