Browse Source

Adding common rerank options to Perform Inference API (#125239)

* wip

* Adding rerank common options

* Linting

* Linting

* [CI] Auto commit changes from spotless

* Update docs/changelog/125239.yaml

* PR feedback

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Ying Mao 6 months ago
parent
commit
a6f685cc2a
66 changed files with 1248 additions and 229 deletions
  1. 6 0
      docs/changelog/125239.yaml
  2. 2 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 12 8
      server/src/main/java/org/elasticsearch/inference/InferenceService.java
  4. 81 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
  5. 236 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
  6. 2 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
  7. 2 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
  8. 2 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
  9. 3 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
  10. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
  11. 7 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java
  12. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java
  13. 7 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java
  14. 7 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java
  15. 7 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java
  16. 22 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java
  17. 9 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java
  18. 8 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java
  19. 14 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java
  20. 26 7
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java
  21. 21 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java
  22. 14 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java
  23. 14 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java
  24. 27 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java
  25. 22 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java
  26. 26 8
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java
  27. 0 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java
  28. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java
  29. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
  30. 15 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
  31. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
  32. 19 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
  33. 14 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
  34. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
  35. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
  36. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java
  37. 7 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java
  38. 95 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java
  39. 12 11
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java
  40. 68 11
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java
  41. 42 39
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java
  42. 19 10
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java
  43. 65 45
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java
  44. 19 10
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java
  45. 35 17
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java
  46. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java
  47. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java
  48. 8 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
  49. 51 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
  50. 10 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
  51. 6 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
  52. 10 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
  53. 8 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
  54. 14 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
  55. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java
  56. 10 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
  57. 10 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
  58. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java
  59. 6 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
  60. 8 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
  61. 34 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
  62. 6 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
  63. 14 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
  64. 7 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java
  65. 34 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java
  66. 2 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java

+ 6 - 0
docs/changelog/125239.yaml

@@ -0,0 +1,6 @@
+pr: 125239
+summary: Adding common rerank options to Perform Inference API
+area: Machine Learning
+type: enhancement
+issues:
+ - 111273

+ 2 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -155,6 +155,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12);
     public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
     public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
+    public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
     public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
     public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -201,6 +202,7 @@ public class TransportVersions {
     public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
     public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
     public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
+    public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 12 - 8
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -91,18 +91,22 @@ public interface InferenceService extends Closeable {
     /**
      * Perform inference on the model.
      *
-     * @param model        The model
-     * @param query        Inference query, mainly for re-ranking
-     * @param input        Inference input
-     * @param stream       Stream inference results
-     * @param taskSettings Settings in the request to override the model's defaults
-     * @param inputType    For search, ingest etc
-     * @param timeout      The timeout for the request
-     * @param listener     Inference result listener
+     * @param model           The model
+     * @param query           Inference query, mainly for re-ranking
+     * @param returnDocuments For re-ranking task type, whether to return documents
+     * @param topN            For re-ranking task type, how many docs to return
+     * @param input           Inference input
+     * @param stream          Stream inference results
+     * @param taskSettings    Settings in the request to override the model's defaults
+     * @param inputType       For search, ingest etc
+     * @param timeout         The timeout for the request
+     * @param listener        Inference result listener
      */
     void infer(
         Model model,
         @Nullable String query,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
         List<String> input,
         boolean stream,
         Map<String, Object> taskSettings,

+ 81 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

@@ -60,6 +60,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
         public static final ParseField INPUT_TYPE = new ParseField("input_type");
         public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
         public static final ParseField QUERY = new ParseField("query");
+        public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
+        public static final ParseField TOP_N = new ParseField("top_n");
         public static final ParseField TIMEOUT = new ParseField("timeout");
 
         static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
@@ -68,6 +70,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
             PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
             PARSER.declareString(Request.Builder::setQuery, QUERY);
+            PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
+            PARSER.declareInt(Request.Builder::setTopN, TOP_N);
             PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
         }
 
@@ -89,6 +93,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
         private final TaskType taskType;
         private final String inferenceEntityId;
         private final String query;
+        private final Boolean returnDocuments;
+        private final Integer topN;
         private final List<String> input;
         private final Map<String, Object> taskSettings;
         private final InputType inputType;
@@ -99,6 +105,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             TaskType taskType,
             String inferenceEntityId,
             String query,
+            Boolean returnDocuments,
+            Integer topN,
             List<String> input,
             Map<String, Object> taskSettings,
             InputType inputType,
@@ -109,6 +117,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 taskType,
                 inferenceEntityId,
                 query,
+                returnDocuments,
+                topN,
                 input,
                 taskSettings,
                 inputType,
@@ -122,6 +132,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             TaskType taskType,
             String inferenceEntityId,
             String query,
+            Boolean returnDocuments,
+            Integer topN,
             List<String> input,
             Map<String, Object> taskSettings,
             InputType inputType,
@@ -133,6 +145,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             this.taskType = taskType;
             this.inferenceEntityId = inferenceEntityId;
             this.query = query;
+            this.returnDocuments = returnDocuments;
+            this.topN = topN;
             this.input = input;
             this.taskSettings = taskSettings;
             this.inputType = inputType;
@@ -164,6 +178,15 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 this.inferenceTimeout = DEFAULT_TIMEOUT;
             }
 
+            if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
+                || in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
+                this.returnDocuments = in.readOptionalBoolean();
+                this.topN = in.readOptionalInt();
+            } else {
+                this.returnDocuments = null;
+                this.topN = null;
+            }
+
             // streaming is not supported yet for transport traffic
             this.stream = false;
         }
@@ -184,6 +207,14 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             return query;
         }
 
+        public Boolean getReturnDocuments() {
+            return returnDocuments;
+        }
+
+        public Integer getTopN() {
+            return topN;
+        }
+
         public Map<String, Object> getTaskSettings() {
             return taskSettings;
         }
@@ -225,6 +256,17 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                     e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
                     return e;
                 }
+            } else if (taskType.equals(TaskType.ANY) == false) {
+                if (returnDocuments != null) {
+                    var e = new ActionRequestValidationException();
+                    e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
+                    return e;
+                }
+                if (topN != null) {
+                    var e = new ActionRequestValidationException();
+                    e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
+                    return e;
+                }
             }
 
             if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
@@ -258,6 +300,12 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 out.writeOptionalString(query);
                 out.writeTimeValue(inferenceTimeout);
             }
+
+            if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
+                || out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
+                out.writeOptionalBoolean(returnDocuments);
+                out.writeOptionalInt(topN);
+            }
         }
 
         // default for easier testing
@@ -283,6 +331,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 && taskType == request.taskType
                 && Objects.equals(inferenceEntityId, request.inferenceEntityId)
                 && Objects.equals(query, request.query)
+                && Objects.equals(returnDocuments, request.returnDocuments)
+                && Objects.equals(topN, request.topN)
                 && Objects.equals(input, request.input)
                 && Objects.equals(taskSettings, request.taskSettings)
                 && inputType == request.inputType
@@ -296,6 +346,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 taskType,
                 inferenceEntityId,
                 query,
+                returnDocuments,
+                topN,
                 input,
                 taskSettings,
                 inputType,
@@ -312,6 +364,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             private InputType inputType = InputType.UNSPECIFIED;
             private Map<String, Object> taskSettings = Map.of();
             private String query;
+            private Boolean returnDocuments;
+            private Integer topN;
             private TimeValue timeout = DEFAULT_TIMEOUT;
             private boolean stream = false;
             private InferenceContext context;
@@ -338,6 +392,16 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 return this;
             }
 
+            public Builder setReturnDocuments(Boolean returnDocuments) {
+                this.returnDocuments = returnDocuments;
+                return this;
+            }
+
+            public Builder setTopN(Integer topN) {
+                this.topN = topN;
+                return this;
+            }
+
             public Builder setInputType(InputType inputType) {
                 this.inputType = inputType;
                 return this;
@@ -373,7 +437,19 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             }
 
             public Request build() {
-                return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
+                return new Request(
+                    taskType,
+                    inferenceEntityId,
+                    query,
+                    returnDocuments,
+                    topN,
+                    input,
+                    taskSettings,
+                    inputType,
+                    timeout,
+                    stream,
+                    context
+                );
             }
         }
 
@@ -384,6 +460,10 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
                 + this.getInferenceEntityId()
                 + ", query="
                 + this.getQuery()
+                + ", returnDocuments="
+                + this.getReturnDocuments()
+                + ", topN="
+                + this.getTopN()
                 + ", input="
                 + this.getInput()
                 + ", taskSettings="

+ 236 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

@@ -44,6 +44,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             randomFrom(TaskType.values()),
             randomAlphaOfLength(6),
             randomAlphaOfLengthOrNull(10),
+            randomBoolean(),
+            randomIntBetween(0, 10),
             randomList(1, 5, () -> randomAlphaOfLength(8)),
             randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
             randomFrom(InputType.values()),
@@ -85,6 +87,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.TEXT_EMBEDDING,
             "model",
             null,
+            null,
+            null,
             List.of("input"),
             null,
             null,
@@ -100,6 +104,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.RERANK,
             "model",
             "query",
+            Boolean.TRUE,
+            34,
             List.of("input"),
             null,
             null,
@@ -119,6 +125,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             null,
             null,
             null,
+            null,
+            null,
             false
         );
         ActionRequestValidationException inputNullError = inputNullRequest.validate();
@@ -131,6 +139,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.TEXT_EMBEDDING,
             "model",
             null,
+            null,
+            null,
             List.of(),
             null,
             null,
@@ -142,11 +152,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
         assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;"));
     }
 
+    public void testValidation_TextEmbedding_WithReturnDocument() {
+        InferenceAction.Request inputRequest = new InferenceAction.Request(
+            TaskType.TEXT_EMBEDDING,
+            "model",
+            null,
+            Boolean.TRUE,
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException inputError = inputRequest.validate();
+        assertNotNull(inputError);
+        assertThat(
+            inputError.getMessage(),
+            is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [text_embedding];")
+        );
+    }
+
+    public void testValidation_TextEmbedding_WithTopN() {
+        InferenceAction.Request inputRequest = new InferenceAction.Request(
+            TaskType.TEXT_EMBEDDING,
+            "model",
+            null,
+            null,
+            12,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException inputError = inputRequest.validate();
+        assertNotNull(inputError);
+        assertThat(inputError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];"));
+    }
+
     public void testValidation_Rerank_Null() {
         InferenceAction.Request queryNullRequest = new InferenceAction.Request(
             TaskType.RERANK,
             "model",
             null,
+            null,
+            null,
             List.of("input"),
             null,
             null,
@@ -163,6 +214,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.RERANK,
             "model",
             "",
+            null,
+            null,
             List.of("input"),
             null,
             null,
@@ -179,6 +232,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.RERANK,
             "model",
             "query",
+            null,
+            null,
             List.of("input"),
             null,
             InputType.SEARCH,
@@ -195,6 +250,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.SPARSE_EMBEDDING,
             "model",
             "",
+            null,
+            null,
             List.of("input"),
             null,
             InputType.SEARCH,
@@ -209,11 +266,56 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
         );
     }
 
+    public void testValidation_SparseEmbedding_WithReturnDocument() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.SPARSE_EMBEDDING,
+            "model",
+            "",
+            Boolean.FALSE,
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(
+            queryError.getMessage(),
+            is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [sparse_embedding];")
+        );
+
+    }
+
+    public void testValidation_SparseEmbedding_WithTopN() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.SPARSE_EMBEDDING,
+            "model",
+            "",
+            null,
+            22,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(
+            queryError.getMessage(),
+            is("Validation Failed: 1: Field [top_n] cannot be specified for task type [sparse_embedding];")
+        );
+    }
+
     public void testValidation_Completion_WithInputType() {
         InferenceAction.Request queryRequest = new InferenceAction.Request(
             TaskType.COMPLETION,
             "model",
             "",
+            null,
+            null,
             List.of("input"),
             null,
             InputType.SEARCH,
@@ -225,11 +327,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
         assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
     }
 
+    public void testValidation_Completion_WithReturnDocuments() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.COMPLETION,
+            "model",
+            "",
+            Boolean.TRUE,
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(
+            queryError.getMessage(),
+            is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [completion];")
+        );
+    }
+
+    public void testValidation_Completion_WithTopN() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.COMPLETION,
+            "model",
+            "",
+            null,
+            77,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [completion];"));
+    }
+
     public void testValidation_ChatCompletion_WithInputType() {
         InferenceAction.Request queryRequest = new InferenceAction.Request(
             TaskType.CHAT_COMPLETION,
             "model",
             "",
+            null,
+            null,
             List.of("input"),
             null,
             InputType.SEARCH,
@@ -244,6 +387,45 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
         );
     }
 
+    public void testValidation_ChatCompletion_WithReturnDocuments() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.CHAT_COMPLETION,
+            "model",
+            "",
+            Boolean.TRUE,
+            null,
+            List.of("input"),
+            null,
+            null,
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(
+            queryError.getMessage(),
+            is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [chat_completion];")
+        );
+    }
+
+    public void testValidation_ChatCompletion_WithTopN() {
+        InferenceAction.Request queryRequest = new InferenceAction.Request(
+            TaskType.CHAT_COMPLETION,
+            "model",
+            "",
+            null,
+            11,
+            List.of("input"),
+            null,
+            InputType.SEARCH,
+            null,
+            false
+        );
+        ActionRequestValidationException queryError = queryRequest.validate();
+        assertNotNull(queryError);
+        assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [chat_completion];"));
+    }
+
     public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
         String singleInputRequest = """
             {
@@ -271,6 +453,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                     nextTask,
                     instance.getInferenceEntityId(),
                     instance.getQuery(),
+                    instance.getReturnDocuments(),
+                    instance.getTopN(),
                     instance.getInput(),
                     instance.getTaskSettings(),
                     instance.getInputType(),
@@ -283,6 +467,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                 instance.getTaskType(),
                 instance.getInferenceEntityId() + "foo",
                 instance.getQuery(),
+                instance.getReturnDocuments(),
+                instance.getTopN(),
                 instance.getInput(),
                 instance.getTaskSettings(),
                 instance.getInputType(),
@@ -297,6 +483,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                     instance.getTaskType(),
                     instance.getInferenceEntityId(),
                     instance.getQuery(),
+                    instance.getReturnDocuments(),
+                    instance.getTopN(),
                     changedInputs,
                     instance.getTaskSettings(),
                     instance.getInputType(),
@@ -317,6 +505,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                     instance.getTaskType(),
                     instance.getInferenceEntityId(),
                     instance.getQuery(),
+                    instance.getReturnDocuments(),
+                    instance.getTopN(),
                     instance.getInput(),
                     taskSettings,
                     instance.getInputType(),
@@ -331,6 +521,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                     instance.getTaskType(),
                     instance.getInferenceEntityId(),
                     instance.getQuery(),
+                    instance.getReturnDocuments(),
+                    instance.getTopN(),
                     instance.getInput(),
                     instance.getTaskSettings(),
                     nextInputType,
@@ -343,6 +535,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                 instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 instance.getQuery() == null ? randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1),
+                instance.getReturnDocuments(),
+                instance.getTopN(),
                 instance.getInput(),
                 instance.getTaskSettings(),
                 instance.getInputType(),
@@ -360,6 +554,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                     instance.getTaskType(),
                     instance.getInferenceEntityId(),
                     instance.getQuery(),
+                    instance.getReturnDocuments(),
+                    instance.getTopN(),
                     instance.getInput(),
                     instance.getTaskSettings(),
                     instance.getInputType(),
@@ -374,6 +570,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                     instance.getTaskType(),
                     instance.getInferenceEntityId(),
                     instance.getQuery(),
+                    instance.getReturnDocuments(),
+                    instance.getTopN(),
                     instance.getInput(),
                     instance.getTaskSettings(),
                     instance.getInputType(),
@@ -395,6 +593,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                 instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 null,
+                null,
+                null,
                 instance.getInput().subList(0, 1),
                 instance.getTaskSettings(),
                 InputType.UNSPECIFIED,
@@ -406,6 +606,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                 instance.getTaskType(),
                 instance.getInferenceEntityId(),
                 null,
+                null,
+                null,
                 instance.getInput(),
                 instance.getTaskSettings(),
                 InputType.UNSPECIFIED,
@@ -420,6 +622,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                         instance.getTaskType(),
                         instance.getInferenceEntityId(),
                         null,
+                        null,
+                        null,
                         instance.getInput(),
                         instance.getTaskSettings(),
                         InputType.INGEST,
@@ -432,6 +636,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                             instance.getTaskType(),
                             instance.getInferenceEntityId(),
                             null,
+                            null,
+                            null,
                             instance.getInput(),
                             instance.getTaskSettings(),
                             InputType.UNSPECIFIED,
@@ -443,6 +649,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                             instance.getTaskType(),
                             instance.getInferenceEntityId(),
                             null,
+                            null,
+                            null,
                             instance.getInput(),
                             instance.getTaskSettings(),
                             instance.getInputType(),
@@ -455,6 +663,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                                 instance.getTaskType(),
                                 instance.getInferenceEntityId(),
                                 instance.getQuery(),
+                                null,
+                                null,
                                 instance.getInput(),
                                 instance.getTaskSettings(),
                                 instance.getInputType(),
@@ -462,9 +672,24 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
                                 false,
                                 InferenceContext.EMPTY_INSTANCE
                             );
-                        } else {
-                            mutated = instance;
-                        }
+                        } else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
+                            && version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) {
+                                mutated = new InferenceAction.Request(
+                                    instance.getTaskType(),
+                                    instance.getInferenceEntityId(),
+                                    instance.getQuery(),
+                                    null,
+                                    null,
+                                    instance.getInput(),
+                                    instance.getTaskSettings(),
+                                    instance.getInputType(),
+                                    instance.getInferenceTimeout(),
+                                    false,
+                                    instance.getContext()
+                                );
+                            } else {
+                                mutated = instance;
+                            }
 
         // We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting
         if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
@@ -481,6 +706,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.TEXT_EMBEDDING,
             "model",
             null,
+            null,
+            null,
             List.of(),
             Map.of(),
             InputType.UNSPECIFIED,
@@ -503,6 +730,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.TEXT_EMBEDDING,
             "model",
             null,
+            null,
+            null,
             List.of(),
             Map.of(),
             InputType.INGEST,
@@ -525,6 +754,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.TEXT_EMBEDDING,
             "model",
             null,
+            null,
+            null,
             List.of("input"),
             Map.of(),
             InputType.UNSPECIFIED,
@@ -548,6 +779,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
             TaskType.TEXT_EMBEDDING,
             "model",
             null,
+            null,
+            null,
             List.of("input"),
             Map.of(),
             InputType.UNSPECIFIED,

+ 2 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

@@ -110,6 +110,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
         public void infer(
             Model model,
             @Nullable String query,
+            @Nullable Boolean returnDocuments,
+            @Nullable Integer topN,
             List<String> input,
             boolean stream,
             Map<String, Object> taskSettings,

+ 2 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

@@ -102,6 +102,8 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
         public void infer(
             Model model,
             @Nullable String query,
+            @Nullable Boolean returnDocuments,
+            @Nullable Integer topN,
             List<String> input,
             boolean stream,
             Map<String, Object> taskSettings,

+ 2 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

@@ -103,6 +103,8 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
         public void infer(
             Model model,
             @Nullable String query,
+            @Nullable Boolean returnDocuments,
+            @Nullable Integer topN,
             List<String> input,
             boolean stream,
             Map<String, Object> taskSettings,

+ 3 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -103,6 +104,8 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
         public void infer(
             Model model,
             String query,
+            @Nullable Boolean returnDocuments,
+            @Nullable Integer topN,
             List<String> input,
             boolean stream,
             Map<String, Object> taskSettings,

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

@@ -77,6 +77,8 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
         service.infer(
             model,
             request.getQuery(),
+            request.getReturnDocuments(),
+            request.getTopN(),
             request.getInput(),
             request.isStreaming(),
             request.getTaskSettings(),

+ 7 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/voyageai/VoyageAIActionCreator.java

@@ -75,7 +75,13 @@ public class VoyageAIActionCreator implements VoyageAIActionVisitor {
             serviceComponents.threadPool(),
             overriddenModel,
             RERANK_HANDLER,
-            (rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model),
+            (rerankInput) -> new VoyageAIRerankRequest(
+                rerankInput.getQuery(),
+                rerankInput.getChunks(),
+                rerankInput.getReturnDocuments(),
+                rerankInput.getTopN(),
+                model
+            ),
             QueryAndDocsInputs.class
         );
 

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java

@@ -69,6 +69,8 @@ public class AlibabaCloudSearchRerankRequestManager extends AlibabaCloudSearchRe
             account,
             rerankInput.getQuery(),
             rerankInput.getChunks(),
+            rerankInput.getReturnDocuments(),
+            rerankInput.getTopN(),
             model
         );
 

+ 7 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java

@@ -49,7 +49,13 @@ public class CohereRerankRequestManager extends CohereRequestManager {
         ActionListener<InferenceServiceResults> listener
     ) {
         var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
-        CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
+        CohereRerankRequest request = new CohereRerankRequest(
+            rerankInput.getQuery(),
+            rerankInput.getChunks(),
+            rerankInput.getReturnDocuments(),
+            rerankInput.getTopN(),
+            model
+        );
 
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }

+ 7 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java

@@ -62,7 +62,13 @@ public class GoogleVertexAiRerankRequestManager extends GoogleVertexAiRequestMan
         ActionListener<InferenceServiceResults> listener
     ) {
         var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
-        GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
+        GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(
+            rerankInput.getQuery(),
+            rerankInput.getChunks(),
+            rerankInput.getReturnDocuments(),
+            rerankInput.getTopN(),
+            model
+        );
 
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }

+ 7 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java

@@ -49,7 +49,13 @@ public class JinaAIRerankRequestManager extends JinaAIRequestManager {
         ActionListener<InferenceServiceResults> listener
     ) {
         var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
-        JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
+        JinaAIRerankRequest request = new JinaAIRerankRequest(
+            rerankInput.getQuery(),
+            rerankInput.getChunks(),
+            rerankInput.getReturnDocuments(),
+            rerankInput.getTopN(),
+            model
+        );
 
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }

+ 22 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

@@ -7,6 +7,8 @@
 
 package org.elasticsearch.xpack.inference.external.http.sender;
 
+import org.elasticsearch.core.Nullable;
+
 import java.util.List;
 import java.util.Objects;
 
@@ -22,15 +24,25 @@ public class QueryAndDocsInputs extends InferenceInputs {
 
     private final String query;
     private final List<String> chunks;
+    private final Boolean returnDocuments;
+    private final Integer topN;
 
     public QueryAndDocsInputs(String query, List<String> chunks) {
-        this(query, chunks, false);
+        this(query, chunks, null, null, false);
     }
 
-    public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
+    public QueryAndDocsInputs(
+        String query,
+        List<String> chunks,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        boolean stream
+    ) {
         super(stream);
         this.query = Objects.requireNonNull(query);
         this.chunks = Objects.requireNonNull(chunks);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
     }
 
     public String getQuery() {
@@ -41,6 +53,14 @@ public class QueryAndDocsInputs extends InferenceInputs {
         return chunks;
     }
 
+    public Boolean getReturnDocuments() {
+        return returnDocuments;
+    }
+
+    public Integer getTopN() {
+        return topN;
+    }
+
     public int inputSize() {
         return chunks.size();
     }

+ 9 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java

@@ -12,6 +12,7 @@ import org.apache.http.client.methods.HttpPost;
 import org.apache.http.client.utils.URIBuilder;
 import org.apache.http.entity.ByteArrayEntity;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -32,6 +33,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
     private final AlibabaCloudSearchAccount account;
     private final String query;
     private final List<String> input;
+    private final Boolean returnDocuments;
+    private final Integer topN;
     private final URI uri;
     private final AlibabaCloudSearchRerankTaskSettings taskSettings;
     private final String model;
@@ -44,6 +47,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
         AlibabaCloudSearchAccount account,
         String query,
         List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
         AlibabaCloudSearchRerankModel rerankModel
     ) {
         Objects.requireNonNull(rerankModel);
@@ -51,6 +56,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
         this.account = Objects.requireNonNull(account);
         this.query = Objects.requireNonNull(query);
         this.input = Objects.requireNonNull(input);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
         taskSettings = rerankModel.getTaskSettings();
         model = rerankModel.getServiceSettings().getCommonSettings().modelId();
         host = rerankModel.getServiceSettings().getCommonSettings().getHost();
@@ -67,7 +74,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
         HttpPost httpPost = new HttpPost(uri);
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8)
+            Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, returnDocuments, topN, taskSettings))
+                .getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);
 

+ 8 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
 
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
@@ -15,9 +16,13 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
-public record AlibabaCloudSearchRerankRequestEntity(String query, List<String> input, AlibabaCloudSearchRerankTaskSettings taskSettings)
-    implements
-        ToXContentObject {
+public record AlibabaCloudSearchRerankRequestEntity(
+    String query,
+    List<String> input,
+    @Nullable Boolean returnDocuments,
+    @Nullable Integer topN,
+    AlibabaCloudSearchRerankTaskSettings taskSettings
+) implements ToXContentObject {
 
     private static final String SEARCH_QUERY = "query";
     private static final String TEXTS_FIELD = "docs";

+ 14 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java

@@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost;
 import org.apache.http.client.utils.URIBuilder;
 import org.apache.http.entity.ByteArrayEntity;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
@@ -28,16 +29,26 @@ public class CohereRerankRequest extends CohereRequest {
     private final CohereAccount account;
     private final String query;
     private final List<String> input;
+    private final Boolean returnDocuments;
+    private final Integer topN;
     private final CohereRerankTaskSettings taskSettings;
     private final String model;
     private final String inferenceEntityId;
 
-    public CohereRerankRequest(String query, List<String> input, CohereRerankModel model) {
+    public CohereRerankRequest(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        CohereRerankModel model
+    ) {
         Objects.requireNonNull(model);
 
         this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri);
         this.input = Objects.requireNonNull(input);
         this.query = Objects.requireNonNull(query);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
         taskSettings = model.getTaskSettings();
         this.model = model.getServiceSettings().modelId();
         inferenceEntityId = model.getInferenceEntityId();
@@ -48,7 +59,8 @@ public class CohereRerankRequest extends CohereRequest {
         HttpPost httpPost = new HttpPost(account.uri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new CohereRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
+            Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
+                .getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);
 

+ 26 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.external.request.cohere;
 
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
@@ -15,9 +16,14 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
-public record CohereRerankRequestEntity(String model, String query, List<String> documents, CohereRerankTaskSettings taskSettings)
-    implements
-        ToXContentObject {
+public record CohereRerankRequestEntity(
+    String model,
+    String query,
+    List<String> documents,
+    @Nullable Boolean returnDocuments,
+    @Nullable Integer topN,
+    CohereRerankTaskSettings taskSettings
+) implements ToXContentObject {
 
     private static final String DOCUMENTS_FIELD = "documents";
     private static final String QUERY_FIELD = "query";
@@ -29,8 +35,15 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
         Objects.requireNonNull(taskSettings);
     }
 
-    public CohereRerankRequestEntity(String query, List<String> input, CohereRerankTaskSettings taskSettings, String model) {
-        this(model, query, input, taskSettings);
+    public CohereRerankRequestEntity(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        CohereRerankTaskSettings taskSettings,
+        String model
+    ) {
+        this(model, query, input, returnDocuments, topN, taskSettings);
     }
 
     @Override
@@ -41,11 +54,17 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
         builder.field(QUERY_FIELD, query);
         builder.field(DOCUMENTS_FIELD, documents);
 
-        if (taskSettings.getDoesReturnDocuments() != null) {
+        // prefer the root level return_documents over task settings
+        if (returnDocuments != null) {
+            builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
+        } else if (taskSettings.getDoesReturnDocuments() != null) {
             builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
         }
 
-        if (taskSettings.getTopNDocumentsOnly() != null) {
+        // prefer the root level top_n over task settings
+        if (topN != null) {
+            builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
+        } else if (taskSettings.getTopNDocumentsOnly() != null) {
             builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
         }
 

+ 21 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java

@@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
 import org.apache.http.entity.ByteArrayEntity;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
@@ -29,10 +30,22 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
 
     private final List<String> input;
 
-    public GoogleVertexAiRerankRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
+    private final Boolean returnDocuments;
+
+    private final Integer topN;
+
+    public GoogleVertexAiRerankRequest(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        GoogleVertexAiRerankModel model
+    ) {
         this.model = Objects.requireNonNull(model);
         this.query = Objects.requireNonNull(query);
         this.input = Objects.requireNonNull(input);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
     }
 
     @Override
@@ -41,7 +54,13 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
             Strings.toString(
-                new GoogleVertexAiRerankRequestEntity(query, input, model.getServiceSettings().modelId(), model.getTaskSettings().topN())
+                new GoogleVertexAiRerankRequestEntity(
+                    query,
+                    input,
+                    returnDocuments,
+                    topN != null ? topN : model.getTaskSettings().topN(),
+                    model.getServiceSettings().modelId()
+                )
             ).getBytes(StandardCharsets.UTF_8)
         );
 

+ 14 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java

@@ -15,9 +15,13 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
-public record GoogleVertexAiRerankRequestEntity(String query, List<String> inputs, @Nullable String model, @Nullable Integer topN)
-    implements
-        ToXContentObject {
+public record GoogleVertexAiRerankRequestEntity(
+    String query,
+    List<String> inputs,
+    @Nullable Boolean returnDocuments,
+    @Nullable Integer topN,
+    @Nullable String model
+) implements ToXContentObject {
 
     private static final String MODEL_FIELD = "model";
     private static final String QUERY_FIELD = "query";
@@ -26,6 +30,7 @@ public record GoogleVertexAiRerankRequestEntity(String query, List<String> input
 
     private static final String CONTENT_FIELD = "content";
     private static final String TOP_N_FIELD = "topN";
+    private static final String IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD = "ignoreRecordDetailsInResponse";
 
     public GoogleVertexAiRerankRequestEntity {
         Objects.requireNonNull(query);
@@ -57,10 +62,16 @@ public record GoogleVertexAiRerankRequestEntity(String query, List<String> input
 
         builder.endArray();
 
+        // prefer the root level top_n over task settings
         if (topN != null) {
             builder.field(TOP_N_FIELD, topN);
         }
 
+        if (returnDocuments != null) {
+            // if returnDocuments = true, we do not want to ignore record details
+            builder.field(IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD, returnDocuments == Boolean.TRUE ? Boolean.FALSE : Boolean.TRUE);
+        }
+
         builder.endObject();
 
         return builder;

+ 14 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java

@@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost;
 import org.apache.http.client.utils.URIBuilder;
 import org.apache.http.entity.ByteArrayEntity;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
@@ -28,16 +29,26 @@ public class JinaAIRerankRequest extends JinaAIRequest {
     private final JinaAIAccount account;
     private final String query;
     private final List<String> input;
+    private final Boolean returnDocuments;
+    private final Integer topN;
     private final JinaAIRerankTaskSettings taskSettings;
     private final String model;
     private final String inferenceEntityId;
 
-    public JinaAIRerankRequest(String query, List<String> input, JinaAIRerankModel model) {
+    public JinaAIRerankRequest(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        JinaAIRerankModel model
+    ) {
         Objects.requireNonNull(model);
 
         this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri);
         this.input = Objects.requireNonNull(input);
         this.query = Objects.requireNonNull(query);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
         taskSettings = model.getTaskSettings();
         this.model = model.getServiceSettings().modelId();
         inferenceEntityId = model.getInferenceEntityId();
@@ -48,7 +59,8 @@ public class JinaAIRerankRequest extends JinaAIRequest {
         HttpPost httpPost = new HttpPost(account.uri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new JinaAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
+            Strings.toString(new JinaAIRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
+                .getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);
 

+ 27 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.external.request.jinaai;
 
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
@@ -15,9 +16,14 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
-public record JinaAIRerankRequestEntity(String model, String query, List<String> documents, JinaAIRerankTaskSettings taskSettings)
-    implements
-        ToXContentObject {
+public record JinaAIRerankRequestEntity(
+    String model,
+    String query,
+    List<String> documents,
+    @Nullable Boolean returnDocuments,
+    @Nullable Integer topN,
+    JinaAIRerankTaskSettings taskSettings
+) implements ToXContentObject {
 
     private static final String DOCUMENTS_FIELD = "documents";
     private static final String QUERY_FIELD = "query";
@@ -30,8 +36,15 @@ public record JinaAIRerankRequestEntity(String model, String query, List<String>
         Objects.requireNonNull(taskSettings);
     }
 
-    public JinaAIRerankRequestEntity(String query, List<String> input, JinaAIRerankTaskSettings taskSettings, String model) {
-        this(model, query, input, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS);
+    public JinaAIRerankRequestEntity(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        JinaAIRerankTaskSettings taskSettings,
+        String model
+    ) {
+        this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS);
     }
 
     @Override
@@ -42,13 +55,18 @@ public record JinaAIRerankRequestEntity(String model, String query, List<String>
         builder.field(QUERY_FIELD, query);
         builder.field(DOCUMENTS_FIELD, documents);
 
-        if (taskSettings.getTopNDocumentsOnly() != null) {
+        // prefer the root level top_n over task settings
+        if (topN != null) {
+            builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
+        } else if (taskSettings.getTopNDocumentsOnly() != null) {
             builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
         }
 
-        var return_documents = taskSettings.getDoesReturnDocuments();
-        if (return_documents != null) {
-            builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, return_documents);
+        // prefer the root level return_documents over task settings
+        if (returnDocuments != null) {
+            builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
+        } else if (taskSettings.getDoesReturnDocuments() != null) {
+            builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
         }
 
         builder.endObject();

+ 22 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.voyageai;
 import org.apache.http.client.methods.HttpPost;
 import org.apache.http.entity.ByteArrayEntity;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
 import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
@@ -23,13 +24,23 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
 
     private final String query;
     private final List<String> input;
+    private final Boolean returnDocuments;
+    private final Integer topN;
     private final VoyageAIRerankModel model;
 
-    public VoyageAIRerankRequest(String query, List<String> input, VoyageAIRerankModel model) {
+    public VoyageAIRerankRequest(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        VoyageAIRerankModel model
+    ) {
         this.model = Objects.requireNonNull(model);
 
         this.input = Objects.requireNonNull(input);
         this.query = Objects.requireNonNull(query);
+        this.returnDocuments = returnDocuments;
+        this.topN = topN;
     }
 
     @Override
@@ -37,8 +48,16 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
         HttpPost httpPost = new HttpPost(model.uri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new VoyageAIRerankRequestEntity(query, input, model.getTaskSettings(), model.getServiceSettings().modelId()))
-                .getBytes(StandardCharsets.UTF_8)
+            Strings.toString(
+                new VoyageAIRerankRequestEntity(
+                    query,
+                    input,
+                    returnDocuments,
+                    topN,
+                    model.getTaskSettings(),
+                    model.getServiceSettings().modelId()
+                )
+            ).getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);
 

+ 26 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.external.request.voyageai;
 
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
@@ -15,15 +16,19 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
-public record VoyageAIRerankRequestEntity(String model, String query, List<String> documents, VoyageAIRerankTaskSettings taskSettings)
-    implements
-        ToXContentObject {
+public record VoyageAIRerankRequestEntity(
+    String model,
+    String query,
+    List<String> documents,
+    @Nullable Boolean returnDocuments,
+    @Nullable Integer topN,
+    VoyageAIRerankTaskSettings taskSettings
+) implements ToXContentObject {
 
     private static final String DOCUMENTS_FIELD = "documents";
     private static final String QUERY_FIELD = "query";
     private static final String MODEL_FIELD = "model";
     public static final String TRUNCATION_FIELD = "truncation";
-    public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
 
     public VoyageAIRerankRequestEntity {
         Objects.requireNonNull(query);
@@ -32,8 +37,15 @@ public record VoyageAIRerankRequestEntity(String model, String query, List<Strin
         Objects.requireNonNull(taskSettings);
     }
 
-    public VoyageAIRerankRequestEntity(String query, List<String> input, VoyageAIRerankTaskSettings taskSettings, String model) {
-        this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
+    public VoyageAIRerankRequestEntity(
+        String query,
+        List<String> input,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
+        VoyageAIRerankTaskSettings taskSettings,
+        String model
+    ) {
+        this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
     }
 
     @Override
@@ -44,11 +56,17 @@ public record VoyageAIRerankRequestEntity(String model, String query, List<Strin
         builder.field(QUERY_FIELD, query);
         builder.field(DOCUMENTS_FIELD, documents);
 
-        if (taskSettings.getDoesReturnDocuments() != null) {
+        // prefer the root level return_documents over task settings
+        if (returnDocuments != null) {
+            builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
+        } else if (taskSettings.getDoesReturnDocuments() != null) {
             builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
         }
 
-        if (taskSettings.getTopKDocumentsOnly() != null) {
+        // prefer the root level top_n over task settings
+        if (topN != null) {
+            builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topN);
+        } else if (taskSettings.getTopKDocumentsOnly() != null) {
             builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
         }
 

+ 0 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java

@@ -103,10 +103,6 @@ public class GoogleVertexAiRerankResponseEntity {
         return parseList(parser, (listParser, index) -> {
             var parsedRankedDoc = RankedDoc.parse(parser);
 
-            if (parsedRankedDoc.content == null) {
-                throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.CONTENT.getPreferredName()));
-            }
-
             if (parsedRankedDoc.score == null) {
                 throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
             }

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java

@@ -232,6 +232,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
                 TaskType.ANY,
                 inferenceId,
                 null,
+                null,
+                null,
                 List.of(query),
                 Map.of(),
                 InputType.INTERNAL_SEARCH,

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

@@ -153,6 +153,8 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
             TaskType.RERANK,
             inferenceId,
             inferenceText,
+            null,
+            null,
             docFeatures,
             Map.of(),
             InputType.INTERNAL_SEARCH,

+ 15 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

@@ -60,6 +60,8 @@ public abstract class SenderService implements InferenceService {
     public void infer(
         Model model,
         @Nullable String query,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
         List<String> input,
         boolean stream,
         Map<String, Object> taskSettings,
@@ -68,7 +70,7 @@ public abstract class SenderService implements InferenceService {
         ActionListener<InferenceServiceResults> listener
     ) {
         init();
-        var inferenceInput = createInput(this, model, input, inputType, query, stream);
+        var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
         doInfer(model, inferenceInput, taskSettings, timeout, listener);
     }
 
@@ -78,11 +80,20 @@ public abstract class SenderService implements InferenceService {
         List<String> input,
         InputType inputType,
         @Nullable String query,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
         boolean stream
     ) {
         return switch (model.getTaskType()) {
             case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
-            case RERANK -> new QueryAndDocsInputs(query, input, stream);
+            case RERANK -> {
+                ValidationException validationException = new ValidationException();
+                service.validateRerankParameters(returnDocuments, topN, validationException);
+                if (validationException.validationErrors().isEmpty() == false) {
+                    throw validationException;
+                }
+                yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
+            }
             case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
                 ValidationException validationException = new ValidationException();
                 service.validateInputType(inputType, model, validationException);
@@ -141,6 +152,8 @@ public abstract class SenderService implements InferenceService {
 
     protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);
 
+    protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {}
+
     protected abstract void doUnifiedCompletionInfer(
         Model model,
         UnifiedChatInput inputs,

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -735,6 +735,8 @@ public final class ServiceUtils {
         service.infer(
             model,
             null,
+            null,
+            null,
             List.of(TEST_EMBEDDING_INPUT),
             false,
             Map.of(),

+ 19 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
@@ -300,6 +301,24 @@ public class AlibabaCloudSearchService extends SenderService {
         ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
     }
 
+    @Override
+    protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
+        if (returnDocuments != null) {
+            validationException.addValidationError(
+                Strings.format(
+                    "Invalid return_documents [%s]. The return_documents option is not supported by this service",
+                    returnDocuments
+                )
+            );
+        }
+
+        if (topN != null) {
+            validationException.addValidationError(
+                Strings.format("Invalid top_n [%s]. The top_n option is not supported by this service", topN)
+            );
+        }
+    }
+
     @Override
     protected void doChunkedInfer(
         Model model,

+ 14 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -620,6 +620,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     public void infer(
         Model model,
         @Nullable String query,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
         List<String> input,
         boolean stream,
         Map<String, Object> taskSettings,
@@ -632,7 +634,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
             if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
                 inferTextEmbedding(esModel, input, inputType, timeout, listener);
             } else if (TaskType.RERANK.equals(taskType)) {
-                inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener);
+                inferRerank(esModel, query, input, returnDocuments, topN, inputType, timeout, taskSettings, listener);
             } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
                 inferSparseEmbedding(esModel, input, inputType, timeout, listener);
             } else {
@@ -693,6 +695,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         ElasticsearchInternalModel model,
         String query,
         List<String> inputs,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer topN,
         InputType inputType,
         TimeValue timeout,
         Map<String, Object> requestTaskSettings,
@@ -701,7 +705,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
 
         var returnDocs = Boolean.TRUE;
-        if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
+        if (returnDocuments != null) {
+            returnDocs = returnDocuments;
+        } else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
             var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
             returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
         }
@@ -709,7 +715,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
 
         ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
-            (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
+            (l, inferenceResult) -> l.onResponse(
+                textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN)
+            )
         );
 
         var maybeDeployListener = mlResultsListener.delegateResponse(
@@ -824,7 +832,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
 
     private RankedDocsResults textSimilarityResultsToRankedDocs(
         List<? extends InferenceResults> results,
-        Function<Integer, String> inputSupplier
+        Function<Integer, String> inputSupplier,
+        @Nullable Integer topN
     ) {
         List<RankedDocsResults.RankedDoc> rankings = new ArrayList<>(results.size());
         for (int i = 0; i < results.size(); i++) {
@@ -851,7 +860,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         }
 
         Collections.sort(rankings);
-        return new RankedDocsResults(rankings);
+        return new RankedDocsResults(topN != null ? rankings.subList(0, topN) : rankings);
     }
 
     public List<DefaultConfigId> defaultConfigIds() {

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java

@@ -30,6 +30,8 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali
         service.infer(
             model,
             model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
+            null,
+            null,
             TEST_INPUT,
             false,
             Map.of(),

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

@@ -423,9 +423,9 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
         when(service.canStream(any())).thenReturn(stream);
         when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
         doAnswer(ans -> {
-            listenerAction.accept(ans.getArgument(7));
+            listenerAction.accept(ans.getArgument(9));
             return null;
-        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
         doAnswer(ans -> {
             listenerAction.accept(ans.getArgument(3));
             return null;

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java

@@ -23,7 +23,7 @@ public class InferenceInputsTests extends ESTestCase {
         var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null);
         assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class));
         assertThat(
-            new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class),
+            new QueryAndDocsInputs("hello", List.of(), Boolean.TRUE, 33, false).castTo(QueryAndDocsInputs.class),
             Matchers.instanceOf(QueryAndDocsInputs.class)
         );
     }

+ 7 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java

@@ -22,7 +22,13 @@ import static org.hamcrest.CoreMatchers.is;
 
 public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase {
     public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
-        var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings());
+        var entity = new AlibabaCloudSearchRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            Boolean.TRUE,
+            22,
+            new AlibabaCloudSearchRerankTaskSettings()
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);

+ 95 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+public class CohereRerankRequestEntityTests extends ESTestCase {
+    public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
+        var entity = new CohereRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            Boolean.TRUE,
+            22,
+            new CohereRerankTaskSettings(null, null, 3),
+            "model"
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}"""));
+    }
+
+    public void testXContent_WritesMinimalFields() throws IOException {
+        var entity = new CohereRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            null,
+            null,
+            new CohereRerankTaskSettings(null, null, null),
+            "model"
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"model":"model","query":"query","documents":["abc"]}"""));
+    }
+
+    public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException {
+        var entity = new CohereRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            Boolean.FALSE,
+            99,
+            new CohereRerankTaskSettings(33, Boolean.TRUE, null),
+            "model"
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}"""));
+    }
+
+    public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException {
+        var entity = new CohereRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            null,
+            null,
+            new CohereRerankTaskSettings(33, Boolean.TRUE, null),
+            "model"
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}"""));
+    }
+}

+ 12 - 11
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java

@@ -20,8 +20,8 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
 import static org.hamcrest.MatcherAssert.assertThat;
 
 public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
-    public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException {
-        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), "model", 8);
+    public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
+        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model");
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -37,13 +37,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
                         "content": "abc"
                     }
                 ],
-                "topN": 8
+                "topN": 10,
+                "ignoreRecordDetailsInResponse": false
             }
             """));
     }
 
-    public void testXContent_SingleRequest_DoesNotWriteModelAndTopNIfNull() throws IOException {
-        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null);
+    public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
+        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null);
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -62,8 +63,8 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
             """));
     }
 
-    public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException {
-        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), "model", 8);
+    public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
+        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model");
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -83,13 +84,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
                         "content": "def"
                     }
                 ],
-                "topN": 8
+                "topN": 12,
+                "ignoreRecordDetailsInResponse": true
             }
             """));
     }
 
-    public void testXContent_MultipleRequests_DoesNotWriteModelAndTopNIfNull() throws IOException {
-        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null);
+    public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
+        var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null);
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -111,5 +113,4 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
             }
             """));
     }
-
 }

+ 68 - 11
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java

@@ -29,11 +29,11 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
 
     private static final String AUTH_HEADER_VALUE = "foo";
 
-    public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
+    public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
         var input = "input";
         var query = "query";
 
-        var request = createRequest(query, input, null, null);
+        var request = createRequest(query, input, null, null, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -53,8 +53,9 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
         var input = "input";
         var query = "query";
         var topN = 1;
+        var taskSettingsTopN = 3;
 
-        var request = createRequest(query, input, null, topN);
+        var request = createRequest(query, input, null, topN, null, taskSettingsTopN);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -71,12 +72,55 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
         assertThat(requestMap.get("topN"), is(topN));
     }
 
+    public void testCreateRequest_UsesTaskSettingsTopNWhenRootLevelIsNull() throws IOException {
+        var input = "input";
+        var query = "query";
+        var topN = 1;
+
+        var request = createRequest(query, input, null, null, null, topN);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+
+        assertThat(requestMap, aMapWithSize(3));
+        assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input))));
+        assertThat(requestMap.get("query"), is(query));
+        assertThat(requestMap.get("topN"), is(topN));
+    }
+
+    public void testCreateRequest_WithReturnDocumentsSet() throws IOException {
+        var input = "input";
+        var query = "query";
+
+        var request = createRequest(query, input, null, null, Boolean.TRUE, null);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+
+        assertThat(requestMap, aMapWithSize(3));
+        assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input))));
+        assertThat(requestMap.get("query"), is(query));
+        assertThat(requestMap.get("ignoreRecordDetailsInResponse"), is(Boolean.FALSE));
+    }
+
     public void testCreateRequest_WithModelSet() throws IOException {
         var input = "input";
         var query = "query";
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, null);
+        var request = createRequest(query, input, modelId, null, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -94,24 +138,37 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
     }
 
     public void testTruncate_DoesNotTruncate() {
-        var request = createRequest("query", "input", null, null);
+        var request = createRequest("query", "input", null, null, null, null);
         var truncatedRequest = request.truncate();
 
         assertThat(truncatedRequest, sameInstance(request));
     }
 
-    private static GoogleVertexAiRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
-        var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, topN);
-
-        return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel);
+    private static GoogleVertexAiRerankRequest createRequest(
+        String query,
+        String input,
+        @Nullable String modelId,
+        @Nullable Integer topN,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer taskSettingsTopN
+    ) {
+        var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, taskSettingsTopN);
+
+        return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel, topN, returnDocuments);
     }
 
     /**
      * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest}
      */
     private static class GoogleVertexAiRerankWithoutAuthRequest extends GoogleVertexAiRerankRequest {
-        GoogleVertexAiRerankWithoutAuthRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
-            super(query, input, model);
+        GoogleVertexAiRerankWithoutAuthRequest(
+            String query,
+            List<String> input,
+            GoogleVertexAiRerankModel model,
+            @Nullable Integer topN,
+            @Nullable Boolean returnDocuments
+        ) {
+            super(query, input, returnDocuments, topN, model);
         }
 
         @Override

+ 42 - 39
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java

@@ -21,8 +21,15 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
 import static org.hamcrest.MatcherAssert.assertThat;
 
 public class JinaAIRerankRequestEntityTests extends ESTestCase {
-    public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException {
-        var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, null), "model");
+    public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
+        var entity = new JinaAIRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            Boolean.TRUE,
+            12,
+            new JinaAIRerankTaskSettings(8, Boolean.FALSE),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -35,13 +42,14 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
                 "documents": [
                     "abc"
                 ],
-                "top_n": 8
+                "top_n": 12,
+                "return_documents": true
             }
             """));
     }
 
-    public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsTrue() throws IOException {
-        var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, true), "model");
+    public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
+        var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, null, new JinaAIRerankTaskSettings(null, null), "model");
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -53,15 +61,20 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
                 "query": "query",
                 "documents": [
                     "abc"
-                ],
-                "top_n": 8,
-                "return_documents": true
+                ]
             }
             """));
     }
 
-    public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsFalse() throws IOException {
-        var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, false), "model");
+    public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
+        var entity = new JinaAIRerankRequestEntity(
+            "query",
+            List.of("abc", "def"),
+            Boolean.FALSE,
+            12,
+            new JinaAIRerankTaskSettings(8, Boolean.TRUE),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -72,16 +85,17 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
                 "model": "model",
                 "query": "query",
                 "documents": [
-                    "abc"
+                    "abc",
+                    "def"
                 ],
-                "top_n": 8,
+                "top_n": 12,
                 "return_documents": false
             }
             """));
     }
 
-    public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOException {
-        var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, "model");
+    public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
+        var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model");
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -92,14 +106,22 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
                 "model": "model",
                 "query": "query",
                 "documents": [
-                    "abc"
+                   "abc",
+                   "def"
                 ]
             }
             """));
     }
 
-    public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException {
-        var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), new JinaAIRerankTaskSettings(8, null), "model");
+    public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
+        var entity = new JinaAIRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            null,
+            null,
+            new JinaAIRerankTaskSettings(8, Boolean.FALSE),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -110,29 +132,10 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
                 "model": "model",
                 "query": "query",
                 "documents": [
-                    "abc",
-                    "def"
+                    "abc"
                 ],
-                "top_n": 8
-            }
-            """));
-    }
-
-    public void testXContent_MultipleRequests_DoesNotWriteTopNIfNull() throws IOException {
-        var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        entity.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
-            {
-                "model": "model",
-                "query": "query",
-                "documents": [
-                   "abc",
-                   "def"
-                ]
+                "top_n": 8,
+                "return_documents": false
             }
             """));
     }

+ 19 - 10
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java

@@ -27,12 +27,12 @@ public class JinaAIRerankRequestTests extends ESTestCase {
 
     private static final String API_KEY = "foo";
 
-    public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
+    public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
         var input = "input";
         var query = "query";
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, null);
+        var request = createRequest(query, input, modelId, null, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -49,13 +49,14 @@ public class JinaAIRerankRequestTests extends ESTestCase {
         assertThat(requestMap.get("model"), is(modelId));
     }
 
-    public void testCreateRequest_WithTopNSet() throws IOException {
+    public void testCreateRequest_WithAllFieldsSet() throws IOException {
         var input = "input";
         var query = "query";
         var topN = 1;
+        var taskSettingsTopN = 2;
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, topN);
+        var request = createRequest(query, input, modelId, topN, Boolean.FALSE, taskSettingsTopN);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -66,10 +67,11 @@ public class JinaAIRerankRequestTests extends ESTestCase {
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
 
-        assertThat(requestMap, aMapWithSize(4));
+        assertThat(requestMap, aMapWithSize(5));
         assertThat(requestMap.get("documents"), is(List.of(input)));
         assertThat(requestMap.get("query"), is(query));
         assertThat(requestMap.get("top_n"), is(topN));
+        assertThat(requestMap.get("return_documents"), is(Boolean.FALSE));
         assertThat(requestMap.get("model"), is(modelId));
     }
 
@@ -78,7 +80,7 @@ public class JinaAIRerankRequestTests extends ESTestCase {
         var query = "query";
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, null);
+        var request = createRequest(query, input, modelId, null, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -96,15 +98,22 @@ public class JinaAIRerankRequestTests extends ESTestCase {
     }
 
     public void testTruncate_DoesNotTruncate() {
-        var request = createRequest("query", "input", "null", null);
+        var request = createRequest("query", "input", "null", null, null, null);
         var truncatedRequest = request.truncate();
 
         assertThat(truncatedRequest, sameInstance(request));
     }
 
-    private static JinaAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
-        var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, topN);
-        return new JinaAIRerankRequest(query, List.of(input), rerankModel);
+    private static JinaAIRerankRequest createRequest(
+        String query,
+        String input,
+        @Nullable String modelId,
+        @Nullable Integer topN,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer taskSettingsTopN
+    ) {
+        var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopN);
+        return new JinaAIRerankRequest(query, List.of(input), returnDocuments, topN, rerankModel);
 
     }
 }

+ 65 - 45
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java

@@ -20,8 +20,15 @@ import java.util.List;
 import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
 
 public class VoyageAIRerankRequestEntityTests extends ESTestCase {
-    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model");
+    public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
+        var entity = new VoyageAIRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            Boolean.TRUE,
+            12,
+            new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -34,13 +41,21 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                 "documents": [
                     "abc"
                 ],
-                "top_k": 8
+                "return_documents": true,
+                "top_k": 12
             }
             """));
     }
 
-    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model");
+    public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
+        var entity = new VoyageAIRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            null,
+            null,
+            new VoyageAIRerankTaskSettings(null, true, null),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -53,14 +68,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                 "documents": [
                     "abc"
                 ],
-                "return_documents": true,
-                "top_k": 8
+                "return_documents": true
             }
             """));
     }
 
-    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model");
+    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
+        var entity = new VoyageAIRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            null,
+            null,
+            new VoyageAIRerankTaskSettings(8, false, true),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -74,13 +95,21 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                     "abc"
                 ],
                 "return_documents": false,
-                "top_k": 8
+                "top_k": 8,
+                "truncation": true
             }
             """));
     }
 
-    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model");
+    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
+        var entity = new VoyageAIRerankRequestEntity(
+            "query",
+            List.of("abc"),
+            null,
+            null,
+            new VoyageAIRerankTaskSettings(8, false, false),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -95,13 +124,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                 ],
                 "return_documents": false,
                 "top_k": 8,
-                "truncation": true
+                "truncation": false
             }
             """));
     }
 
-    public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model");
+    public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
+        var entity = new VoyageAIRerankRequestEntity(
+            "query",
+            List.of("abc", "def"),
+            Boolean.FALSE,
+            11,
+            new VoyageAIRerankTaskSettings(8, null, null),
+            "model"
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -112,17 +148,17 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                 "model": "model",
                 "query": "query",
                 "documents": [
-                    "abc"
+                    "abc",
+                    "def"
                 ],
                 "return_documents": false,
-                "top_k": 8,
-                "truncation": false
+                "top_k": 11
             }
             """));
     }
 
-    public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model");
+    public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
+        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model");
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -133,17 +169,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                 "model": "model",
                 "query": "query",
                 "documents": [
-                    "abc"
+                   "abc",
+                   "def"
                 ]
             }
             """));
     }
 
-    public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException {
+    public void testXContent_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
         var entity = new VoyageAIRerankRequestEntity(
             "query",
-            List.of("abc", "def"),
-            new VoyageAIRerankTaskSettings(8, null, null),
+            List.of("abc"),
+            null,
+            null,
+            new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null),
             "model"
         );
 
@@ -156,31 +195,12 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
                 "model": "model",
                 "query": "query",
                 "documents": [
-                    "abc",
-                    "def"
+                    "abc"
                 ],
+                "return_documents": false,
                 "top_k": 8
             }
             """));
     }
 
-    public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
-        var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
-
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
-        entity.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
-            {
-                "model": "model",
-                "query": "query",
-                "documents": [
-                   "abc",
-                   "def"
-                ]
-            }
-            """));
-    }
-
 }

+ 19 - 10
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java

@@ -27,12 +27,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
 
     private static final String API_KEY = "foo";
 
-    public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
+    public void testCreateRequest_WithMinimalFields() throws IOException {
         var input = "input";
         var query = "query";
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, null);
+        var request = createRequest(query, input, modelId, null, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -49,13 +49,14 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
         assertThat(requestMap.get("model"), is(modelId));
     }
 
-    public void testCreateRequest_WithTopNSet() throws IOException {
+    public void testCreateRequest_WithAllFieldsDefined() throws IOException {
         var input = "input";
         var query = "query";
         var topK = 1;
+        var taskSettingsTopK = 2;
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, topK);
+        var request = createRequest(query, input, modelId, topK, Boolean.FALSE, taskSettingsTopK);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -66,11 +67,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
 
-        assertThat(requestMap, aMapWithSize(4));
+        assertThat(requestMap, aMapWithSize(5));
         assertThat(requestMap.get("documents"), is(List.of(input)));
         assertThat(requestMap.get("query"), is(query));
         assertThat(requestMap.get("top_k"), is(topK));
         assertThat(requestMap.get("model"), is(modelId));
+        assertThat(requestMap.get("return_documents"), is(Boolean.FALSE));
     }
 
     public void testCreateRequest_WithModelSet() throws IOException {
@@ -78,7 +80,7 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
         var query = "query";
         var modelId = "model";
 
-        var request = createRequest(query, input, modelId, null);
+        var request = createRequest(query, input, modelId, null, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -96,15 +98,22 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
     }
 
     public void testTruncate_DoesNotTruncate() {
-        var request = createRequest("query", "input", "null", null);
+        var request = createRequest("query", "input", "null", null, null, null);
         var truncatedRequest = request.truncate();
 
         assertThat(truncatedRequest, sameInstance(request));
     }
 
-    private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) {
-        var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK);
-        return new VoyageAIRerankRequest(query, List.of(input), rerankModel);
+    private static VoyageAIRerankRequest createRequest(
+        String query,
+        String input,
+        @Nullable String modelId,
+        @Nullable Integer topK,
+        @Nullable Boolean returnDocuments,
+        @Nullable Integer taskSettingsTopK
+    ) {
+        var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopK);
+        return new VoyageAIRerankRequest(query, List.of(input), returnDocuments, topK, rerankModel);
 
     }
 }

+ 35 - 17
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java

@@ -42,6 +42,26 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
         assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
     }
 
+    public void testFromResponse_CreatesResultsForASingleItem_NoContent() throws IOException {
+        String responseJson = """
+            {
+                 "records": [
+                     {
+                         "id": "2",
+                         "title": "title 2",
+                         "score": 0.97
+                     }
+                ]
+            }
+            """;
+
+        RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null))));
+    }
+
     public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
         String responseJson = """
             {
@@ -72,40 +92,38 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
         );
     }
 
-    public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() {
+    public void testFromResponse_CreatesResultsForMultipleItems_NoContent() throws IOException {
         String responseJson = """
             {
-                 "not_records": [
+                 "records": [
                      {
                          "id": "2",
                          "title": "title 2",
-                         "content": "content 2",
                          "score": 0.97
                      },
                      {
                          "id": "1",
                          "title": "title 1",
-                         "content": "content 1",
                          "score": 0.90
                      }
                 ]
             }
             """;
 
-        var thrownException = expectThrows(
-            IllegalStateException.class,
-            () -> GoogleVertexAiRerankResponseEntity.fromResponse(
-                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
-            )
+        RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse(
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
-        assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response"));
+        assertThat(
+            parsedResults.getRankedDocs(),
+            is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null), new RankedDocsResults.RankedDoc(1, 0.90F, null)))
+        );
     }
 
-    public void testFromResponse_FailsWhenContentFieldIsNotPresent() {
+    public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() {
         String responseJson = """
             {
-                 "records": [
+                 "not_records": [
                      {
                          "id": "2",
                          "title": "title 2",
@@ -113,10 +131,10 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
                          "score": 0.97
                      },
                      {
-                        "id": "1",
-                        "title": "title 1",
-                        "not_content": "content 1",
-                        "score": 0.97
+                         "id": "1",
+                         "title": "title 1",
+                         "content": "content 1",
+                         "score": 0.90
                      }
                 ]
             }
@@ -129,7 +147,7 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
             )
         );
 
-        assertThat(thrownException.getMessage(), is("Failed to find required field [content] in Google Vertex AI rerank response"));
+        assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response"));
     }
 
     public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java

@@ -98,6 +98,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
                         TaskType.RERANK,
                         this.inferenceId,
                         inferenceText,
+                        null,
+                        null,
                         docFeatures,
                         Map.of("inferenceResultCount", inferenceResultCount),
                         InputType.INTERNAL_SEARCH,

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java

@@ -225,6 +225,8 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin {
                             TaskType.RERANK,
                             inferenceId,
                             inferenceText,
+                            null,
+                            null,
                             docFeatures,
                             Map.of("throwing", true),
                             InputType.INTERNAL_SEARCH,

+ 8 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java

@@ -910,11 +910,11 @@ public class ServiceUtilsTests extends ESTestCase {
         when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
 
         doAnswer(invocation -> {
-            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
             listener.onResponse(new TextEmbeddingFloatResults(List.of()));
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);
@@ -932,11 +932,11 @@ public class ServiceUtilsTests extends ESTestCase {
         when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
 
         doAnswer(invocation -> {
-            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
             listener.onResponse(new TextEmbeddingByteResults(List.of()));
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);
@@ -956,11 +956,11 @@ public class ServiceUtilsTests extends ESTestCase {
         var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults();
 
         doAnswer(invocation -> {
-            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
             listener.onResponse(textEmbedding);
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);
@@ -979,11 +979,11 @@ public class ServiceUtilsTests extends ESTestCase {
         var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults();
 
         doAnswer(invocation -> {
-            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
             listener.onResponse(textEmbedding);
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);

+ 51 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

@@ -389,6 +389,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -431,6 +433,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -446,6 +450,53 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
         }
     }
 
+    public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        Map<String, Object> serviceSettingsMap = new HashMap<>();
+        serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
+        serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
+        serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
+        serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536);
+
+        Map<String, Object> taskSettingsMap = new HashMap<>();
+
+        Map<String, Object> secretSettingsMap = new HashMap<>();
+        secretSettingsMap.put("api_key", "secret");
+
+        var model = AlibabaCloudSearchEmbeddingsModelTests.createModel(
+            "service",
+            TaskType.RERANK,
+            serviceSettingsMap,
+            taskSettingsMap,
+            secretSettingsMap
+        );
+        try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            var thrownException = expectThrows(
+                ValidationException.class,
+                () -> service.infer(
+                    model,
+                    "hi",
+                    Boolean.TRUE,
+                    10,
+                    List.of("a"),
+                    false,
+                    new HashMap<>(),
+                    null,
+                    InferenceAction.Request.DEFAULT_TIMEOUT,
+                    listener
+                )
+            );
+            assertThat(
+                thrownException.getMessage(),
+                is(
+                    "Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this "
+                        + "service;2: Invalid top_n [10]. The top_n option is not supported by this service;"
+                )
+            );
+        }
+    }
+
     public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException {
         testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings());
     }

+ 10 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

@@ -932,6 +932,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -979,6 +981,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -1029,6 +1033,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                 service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of("abc"),
                     false,
                     new HashMap<>(),
@@ -1071,6 +1077,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                 service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of("abc"),
                     false,
                     new HashMap<>(),
@@ -1414,6 +1422,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 6 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

@@ -458,6 +458,8 @@ public class AnthropicServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -513,6 +515,8 @@ public class AnthropicServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("input"),
                 false,
                 new HashMap<>(),
@@ -571,6 +575,8 @@ public class AnthropicServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 true,
                 new HashMap<>(),

+ 10 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

@@ -1096,6 +1096,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -1134,6 +1136,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                 () -> service.infer(
                     mockModel,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -1296,6 +1300,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1347,6 +1353,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1403,6 +1411,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 true,
                 new HashMap<>(),

+ 8 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

@@ -766,6 +766,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -822,6 +824,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1286,6 +1290,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1453,6 +1459,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 true,
                 new HashMap<>(),

+ 14 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

@@ -788,6 +788,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -856,6 +858,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1147,6 +1151,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1207,6 +1213,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1281,6 +1289,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null),
@@ -1353,6 +1363,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1629,6 +1641,8 @@ public class CohereServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 true,
                 new HashMap<>(),

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java

@@ -232,7 +232,7 @@ public class DeepSeekServiceTests extends ESTestCase {
         try (var service = createService()) {
             var model = createModel(service, TaskType.COMPLETION);
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
+            service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
             var result = listener.actionGet(TIMEOUT);
             assertThat(result, isA(ChatCompletionResults.class));
             var completionResults = (ChatCompletionResults) result;
@@ -255,7 +255,7 @@ public class DeepSeekServiceTests extends ESTestCase {
         try (var service = createService()) {
             var model = createModel(service, TaskType.COMPLETION);
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
+            service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
             InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent("""
                 {"completion":[{"delta":"hello, world"}]}""");
         }

+ 10 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -368,6 +368,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -404,6 +406,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -443,6 +447,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -494,6 +500,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("input text"),
                 false,
                 new HashMap<>(),
@@ -551,6 +559,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                 service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of("input text"),
                     false,
                     new HashMap<>(),

+ 10 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

@@ -662,6 +662,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -700,6 +702,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -775,6 +779,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("input"),
                 false,
                 new HashMap<>(),
@@ -832,6 +838,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of(input),
                 false,
                 new HashMap<>(),
@@ -1005,6 +1013,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java

@@ -65,6 +65,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),

+ 6 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

@@ -556,6 +556,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -593,6 +595,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of("abc"),
                     false,
                     new HashMap<>(),
@@ -627,6 +631,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 8 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

@@ -602,6 +602,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -641,6 +643,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -697,6 +701,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of(input),
                 false,
                 new HashMap<>(),
@@ -840,6 +846,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 34 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java

@@ -782,6 +782,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -1044,6 +1046,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1076,6 +1080,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2"),
                 false,
                 new HashMap<>(),
@@ -1132,6 +1138,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1201,6 +1209,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1254,6 +1264,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1320,7 +1332,18 @@ public class JinaAIServiceTests extends ESTestCase {
                 JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            service.infer(
+                model,
+                null,
+                null,
+                null,
+                List.of("abc"),
+                false,
+                new HashMap<>(),
+                null,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -1371,6 +1394,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3"),
                 false,
                 new HashMap<>(),
@@ -1454,6 +1479,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3", "candidate4"),
                 false,
                 new HashMap<>(),
@@ -1549,6 +1576,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3"),
                 false,
                 new HashMap<>(),
@@ -1630,6 +1659,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3", "candidate4"),
                 false,
                 new HashMap<>(),
@@ -1724,6 +1755,8 @@ public class JinaAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 6 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java

@@ -586,6 +586,8 @@ public class MistralServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -625,6 +627,8 @@ public class MistralServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -781,6 +785,8 @@ public class MistralServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 14 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

@@ -852,6 +852,8 @@ public class OpenAiServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -890,6 +892,8 @@ public class OpenAiServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -925,6 +929,8 @@ public class OpenAiServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -964,6 +970,8 @@ public class OpenAiServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -1024,6 +1032,8 @@ public class OpenAiServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1263,6 +1273,8 @@ public class OpenAiServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 true,
                 new HashMap<>(),
@@ -1794,6 +1806,8 @@ public class OpenAiServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 7 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java

@@ -63,6 +63,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
             .infer(
                 eq(mockModel),
                 eq(null),
+                eq(null),
+                eq(null),
                 eq(TEST_INPUT),
                 eq(false),
                 eq(Map.of()),
@@ -97,13 +99,15 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
 
     private void mockSuccessfulCallToService(String query, InferenceServiceResults result) {
         doAnswer(ans -> {
-            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(7);
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(9);
             responseListener.onResponse(result);
             return null;
         }).when(mockInferenceService)
             .infer(
                 eq(mockModel),
                 eq(query),
+                eq(null),
+                eq(null),
                 eq(TEST_INPUT),
                 eq(false),
                 eq(Map.of()),
@@ -120,6 +124,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
         verify(mockInferenceService).infer(
             eq(mockModel),
             eq(withQuery ? TEST_QUERY : null),
+            eq(null),
+            eq(null),
             eq(TEST_INPUT),
             eq(false),
             eq(Map.of()),

+ 34 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java

@@ -722,6 +722,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 mockModel,
                 null,
+                null,
+                null,
                 List.of(""),
                 false,
                 new HashMap<>(),
@@ -768,6 +770,8 @@ public class VoyageAIServiceTests extends ESTestCase {
                 () -> service.infer(
                     model,
                     null,
+                    null,
+                    null,
                     List.of(""),
                     false,
                     new HashMap<>(),
@@ -1017,6 +1021,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1049,6 +1055,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2"),
                 false,
                 new HashMap<>(),
@@ -1103,6 +1111,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1183,6 +1193,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),
@@ -1260,7 +1272,18 @@ public class VoyageAIServiceTests extends ESTestCase {
                 (SimilarityMeasure) null
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+            service.infer(
+                model,
+                null,
+                null,
+                null,
+                List.of("abc"),
+                false,
+                new HashMap<>(),
+                null,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
 
             var result = listener.actionGet(TIMEOUT);
 
@@ -1315,6 +1338,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3"),
                 false,
                 new HashMap<>(),
@@ -1401,6 +1426,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3", "candidate4"),
                 false,
                 new HashMap<>(),
@@ -1493,6 +1520,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3"),
                 false,
                 new HashMap<>(),
@@ -1569,6 +1598,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 "query",
+                null,
+                null,
                 List.of("candidate1", "candidate2", "candidate3", "candidate4"),
                 false,
                 new HashMap<>(),
@@ -1663,6 +1694,8 @@ public class VoyageAIServiceTests extends ESTestCase {
             service.infer(
                 model,
                 null,
+                null,
+                null,
                 List.of("abc"),
                 false,
                 new HashMap<>(),

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java

@@ -123,6 +123,8 @@ public class TransportCoordinatedInferenceAction extends HandledTransportAction<
                 TaskType.ANY,
                 request.getModelId(),
                 null,
+                null,
+                null,
                 request.getInputs(),
                 request.getTaskSettings(),
                 inputType,