Browse Source

[8.19] [Inference API] Add "rerank" task type to "elastic" provider (#126022) #129196

Tim Grein 4 months ago
parent
commit
bac3e16ed9
20 changed files with 1232 additions and 28 deletions
  1. 2 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  2. 15 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
  3. 94 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java
  4. 59 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java
  5. 70 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntity.java
  6. 37 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java
  7. 12 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java
  8. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java
  9. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java
  10. 36 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java
  11. 3 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java
  12. 104 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java
  13. 126 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java
  14. 122 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java
  15. 89 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java
  16. 148 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntityTests.java
  17. 131 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
  18. 75 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java
  19. 30 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java
  20. 76 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java

+ 2 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -238,6 +238,8 @@ public class TransportVersions {
     public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
     public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);
     public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
+    public static final TransportVersion ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19 = def(8_841_0_48);
+
     /*
      * STOP! READ THIS FIRST! No, really,
      *        ____ _____ ___  ____  _        ____  _____    _    ____    _____ _   _ ___ ____    _____ ___ ____  ____ _____ _

+ 15 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -71,6 +71,7 @@ import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingR
 import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
@@ -166,7 +167,7 @@ public class InferenceNamedWriteablesProvider {
         addAnthropicNamedWritables(namedWriteables);
         addAmazonBedrockNamedWriteables(namedWriteables);
         addAwsNamedWriteables(namedWriteables);
-        addEisNamedWriteables(namedWriteables);
+        addElasticNamedWriteables(namedWriteables);
         addAlibabaCloudSearchNamedWriteables(namedWriteables);
         addJinaAINamedWriteables(namedWriteables);
         addVoyageAINamedWriteables(namedWriteables);
@@ -742,7 +743,8 @@ public class InferenceNamedWriteablesProvider {
         );
     }
 
-    private static void addEisNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+    private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
+        // Sparse Text Embeddings
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
                 ServiceSettings.class,
@@ -750,6 +752,8 @@ public class InferenceNamedWriteablesProvider {
                 ElasticInferenceServiceSparseEmbeddingsServiceSettings::new
             )
         );
+
+        // Completion
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
                 ServiceSettings.class,
@@ -757,5 +761,14 @@ public class InferenceNamedWriteablesProvider {
                 ElasticInferenceServiceCompletionServiceSettings::new
             )
         );
+
+        // Rerank
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                ElasticInferenceServiceRerankServiceSettings.NAME,
+                ElasticInferenceServiceRerankServiceSettings::new
+            )
+        );
     }
 }

+ 94 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java

@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.client.methods.HttpRequestBase;
+import org.apache.http.entity.ByteArrayEntity;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
+import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
+import org.elasticsearch.xpack.inference.telemetry.TraceContext;
+import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Rerank request sent to the Elastic Inference Service.
+ * <p>
+ * Serializes the query, the documents and the optional {@code topN} into a JSON {@code POST} addressed to the
+ * model's URI and propagates the supplied trace context on the outgoing HTTP request.
+ */
+public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServiceRequest {
+
+    private final ElasticInferenceServiceRerankModel model;
+    private final String query;
+    private final List<String> documents;
+    private final Integer topN;
+    private final TraceContextHandler traceContextHandler;
+
+    /**
+     * @param query        the query the documents are ranked against
+     * @param documents    the documents to rank
+     * @param topN         forwarded unchanged to the request entity; may be {@code null}
+     * @param model        the rerank model providing the model id, the URI and the inference entity id; must not be {@code null}
+     * @param traceContext the trace context to propagate on the HTTP request
+     * @param metadata     request metadata handed to the base request
+     */
+    public ElasticInferenceServiceRerankRequest(
+        String query,
+        List<String> documents,
+        Integer topN,
+        ElasticInferenceServiceRerankModel model,
+        TraceContext traceContext,
+        ElasticInferenceServiceRequestMetadata metadata
+    ) {
+        super(metadata);
+        this.model = Objects.requireNonNull(model);
+        this.query = query;
+        this.documents = documents;
+        this.topN = topN;
+        this.traceContextHandler = new TraceContextHandler(traceContext);
+    }
+
+    @Override
+    public HttpRequestBase createHttpRequestBase() {
+        var post = new HttpPost(getURI());
+
+        var entity = new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN);
+        byte[] payload = Strings.toString(entity).getBytes(StandardCharsets.UTF_8);
+        post.setEntity(new ByteArrayEntity(payload));
+
+        // Trace headers are added before the content type header.
+        traceContextHandler.propagateTraceContext(post);
+        post.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
+
+        return post;
+    }
+
+    public TraceContext getTraceContext() {
+        return traceContextHandler.traceContext();
+    }
+
+    @Override
+    public URI getURI() {
+        return model.uri();
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    /**
+     * Rerank requests are never truncated, so the request itself is returned.
+     */
+    @Override
+    public Request truncate() {
+        return this;
+    }
+
+    /**
+     * Rerank requests are never truncated, so there is no truncation info to report.
+     */
+    @Override
+    public boolean[] getTruncationInfo() {
+        return null;
+    }
+}

+ 59 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java

@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * JSON body of a rerank request to the Elastic Inference Service.
+ * <p>
+ * Written as an object with the fields {@code query}, {@code model}, {@code top_n} (only when set) and
+ * {@code documents}, in that order.
+ *
+ * @param query             the query to rank against; must not be {@code null}
+ * @param documents         the documents to rank; must not be {@code null}
+ * @param modelId           the id of the model to use; must not be {@code null}
+ * @param topNDocumentsOnly optional value for {@code top_n}; the field is omitted when {@code null}
+ */
+public record ElasticInferenceServiceRerankRequestEntity(
+    String query,
+    List<String> documents,
+    String modelId,
+    @Nullable Integer topNDocumentsOnly
+) implements ToXContentObject {
+
+    private static final String QUERY_FIELD = "query";
+    private static final String MODEL_FIELD = "model";
+    private static final String TOP_N_DOCUMENTS_ONLY_FIELD = "top_n";
+    private static final String DOCUMENTS_FIELD = "documents";
+
+    public ElasticInferenceServiceRerankRequestEntity {
+        Objects.requireNonNull(query);
+        Objects.requireNonNull(documents);
+        Objects.requireNonNull(modelId);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(QUERY_FIELD, query).field(MODEL_FIELD, modelId);
+
+        // top_n is optional and left out of the body entirely when it was not provided
+        if (topNDocumentsOnly != null) {
+            builder.field(TOP_N_DOCUMENTS_ONLY_FIELD, topNDocumentsOnly);
+        }
+
+        builder.startArray(DOCUMENTS_FIELD);
+        for (String doc : documents) {
+            builder.value(doc);
+        }
+
+        return builder.endArray().endObject();
+    }
+}

+ 70 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntity.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.elastic;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * Parses a rerank response from the Elastic Inference Service into {@link RankedDocsResults}.
+ * <p>
+ * The expected body is an object with a {@code results} array whose entries carry an {@code index} and a
+ * {@code relevance_score}. Unknown fields are ignored both at the top level and inside each entry.
+ */
+public class ElasticInferenceServiceRerankResponseEntity {
+
+    record RerankResult(List<RerankResultEntry> entries) {
+
+        @SuppressWarnings("unchecked")
+        public static final ConstructingObjectParser<RerankResult, Void> PARSER = new ConstructingObjectParser<>(
+            RerankResult.class.getSimpleName(),
+            true,
+            args -> new RerankResult((List<RerankResultEntry>) args[0])
+        );
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), RerankResultEntry.PARSER::apply, new ParseField("results"));
+        }
+
+        record RerankResultEntry(Integer index, Float relevanceScore) {
+
+            // Lenient like the enclosing parser: an additional field inside a result entry must not fail the
+            // whole response when additional top-level fields are already tolerated.
+            public static final ConstructingObjectParser<RerankResultEntry, Void> PARSER = new ConstructingObjectParser<>(
+                RerankResultEntry.class.getSimpleName(),
+                true,
+                args -> new RerankResultEntry((Integer) args[0], (Float) args[1])
+            );
+
+            static {
+                PARSER.declareInt(constructorArg(), new ParseField("index"));
+                PARSER.declareFloat(constructorArg(), new ParseField("relevance_score"));
+            }
+
+            /**
+             * Converts this entry into a ranked doc; the document text is not part of the entry and is left {@code null}.
+             */
+            public RankedDocsResults.RankedDoc toRankedDoc() {
+                return new RankedDocsResults.RankedDoc(index, relevanceScore, null);
+            }
+        }
+    }
+
+    /**
+     * Parses the JSON body of {@code response}.
+     *
+     * @param response the HTTP result whose body holds the rerank response
+     * @return the ranked documents in the order they appear in the response
+     * @throws IOException if the body cannot be read
+     */
+    public static InferenceServiceResults fromResponse(HttpResult response) throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            var rerankResult = RerankResult.PARSER.apply(jsonParser, null);
+
+            return new RankedDocsResults(rerankResult.entries().stream().map(RerankResult.RerankResultEntry::toRankedDoc).toList());
+        }
+    }
+
+    private ElasticInferenceServiceRerankResponseEntity() {}
+}

+ 37 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -51,6 +51,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 
@@ -77,7 +78,11 @@ public class ElasticInferenceService extends SenderService {
     public static final String NAME = "elastic";
     public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
 
-    private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
+    private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
+        TaskType.SPARSE_EMBEDDING,
+        TaskType.CHAT_COMPLETION,
+        TaskType.RERANK
+    );
     private static final String SERVICE_NAME = "Elastic";
 
     // rainbow-sprinkles
@@ -91,7 +96,7 @@ public class ElasticInferenceService extends SenderService {
     /**
      * The task types that the {@link InferenceAction.Request} can accept.
      */
-    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK);
 
     public static String defaultEndpointId(String modelId) {
         return Strings.format(".%s-elastic", modelId);
@@ -161,6 +166,18 @@ public class ElasticInferenceService extends SenderService {
         authorizationHandler.init();
     }
 
+    /**
+     * Rejects the {@code return_documents} option, which this service does not support.
+     * {@code topN} is not validated here.
+     */
+    @Override
+    protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
+        if (returnDocuments == null) {
+            return;
+        }
+
+        var message = org.elasticsearch.core.Strings.format(
+            "Invalid return_documents [%s]. The return_documents option is not supported by this service",
+            returnDocuments
+        );
+        validationException.addValidationError(message);
+    }
+
     /**
      * Only use this in tests.
      *
@@ -333,7 +350,7 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
         @Nullable Map<String, Object> secretSettings,
-        ElasticInferenceServiceComponents eisServiceComponents,
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents,
         String failureMessage,
         ConfigurationParseContext context
     ) {
@@ -345,7 +362,7 @@ public class ElasticInferenceService extends SenderService {
                 serviceSettings,
                 taskSettings,
                 secretSettings,
-                eisServiceComponents,
+                elasticInferenceServiceComponents,
                 context
             );
             case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
@@ -355,7 +372,17 @@ public class ElasticInferenceService extends SenderService {
                 serviceSettings,
                 taskSettings,
                 secretSettings,
-                eisServiceComponents,
+                elasticInferenceServiceComponents,
+                context
+            );
+            case RERANK -> new ElasticInferenceServiceRerankModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                elasticInferenceServiceComponents,
                 context
             );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
@@ -451,9 +478,8 @@ public class ElasticInferenceService extends SenderService {
 
                 configurationMap.put(
                     MODEL_ID,
-                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
-                        "The name of the model to use for the inference task."
-                    )
+                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK))
+                        .setDescription("The name of the model to use for the inference task.")
                         .setLabel("Model ID")
                         .setRequired(true)
                         .setSensitive(false)
@@ -476,7 +502,9 @@ public class ElasticInferenceService extends SenderService {
                 );
 
                 configurationMap.putAll(
-                    RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))
+                    RateLimitSettings.toSettingsConfiguration(
+                        EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)
+                    )
                 );
 
                 return new InferenceServiceConfiguration.Builder().setService(NAME)

+ 12 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java

@@ -7,14 +7,15 @@
 
 package org.elasticsearch.xpack.inference.services.elastic;
 
-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
 import java.util.Objects;
 
-public abstract class ElasticInferenceServiceModel extends Model {
+public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel {
 
     private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings;
 
@@ -35,12 +36,18 @@ public abstract class ElasticInferenceServiceModel extends Model {
     public ElasticInferenceServiceModel(ElasticInferenceServiceModel model, ServiceSettings serviceSettings) {
         super(model, serviceSettings);
 
-        this.rateLimitServiceSettings = model.rateLimitServiceSettings();
+        this.rateLimitServiceSettings = model.rateLimitServiceSettings;
         this.elasticInferenceServiceComponents = model.elasticInferenceServiceComponents();
     }
 
-    public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings() {
-        return rateLimitServiceSettings;
+    @Override
+    public int rateLimitGroupingHash() {
+        // Group by model id only. This is the base implementation inherited by every Elastic Inference Service model, not just rerank.
+        return Objects.hash(this.getServiceSettings().modelId());
+    }
+
+    /**
+     * Returns the rate limit settings held by this model's rate limit service settings.
+     */
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitServiceSettings.rateLimitSettings();
     }
 
     public ElasticInferenceServiceComponents elasticInferenceServiceComponents() {

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java

@@ -20,7 +20,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM
     private final ElasticInferenceServiceRequestMetadata requestMetadata;
 
     protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings());
         this.requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
     }
 
@@ -32,7 +32,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM
         public static RateLimitGrouping of(ElasticInferenceServiceModel model) {
             Objects.requireNonNull(model);
 
-            return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
+            return new RateLimitGrouping(model.rateLimitGroupingHash());
         }
     }
 }

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java

@@ -30,6 +30,7 @@ import java.util.function.Supplier;
 import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
 import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
 
+// TODO: remove and use GenericRequestManager in ElasticInferenceServiceActionCreator
 public class ElasticInferenceServiceSparseEmbeddingsRequestManager extends ElasticInferenceServiceRequestManager {
 
     private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceSparseEmbeddingsRequestManager.class);

+ 36 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java

@@ -7,19 +7,27 @@
 
 package org.elasticsearch.xpack.inference.services.elastic.action;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest;
+import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 
-import java.util.Locale;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
+import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext;
 
 public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {
 
@@ -29,6 +37,11 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
 
     private final TraceContext traceContext;
 
+    static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler(
+        "elastic rerank",
+        (request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
+    );
+
     public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
         this.sender = Objects.requireNonNull(sender);
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
@@ -39,8 +52,29 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
     public ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model) {
         var requestManager = new ElasticInferenceServiceSparseEmbeddingsRequestManager(model, serviceComponents, traceContext);
         var errorMessage = constructFailedToSendRequestMessage(
-            String.format(Locale.ROOT, "%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
+            Strings.format("%s sparse embeddings", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)
+        );
+        return new SenderExecutableAction(sender, requestManager, errorMessage);
+    }
+
+    /**
+     * Creates the action that sends rerank requests for the given model through the sender.
+     *
+     * @param model the rerank model the requests are built for
+     * @return an action wrapping a generic request manager configured with {@code RERANK_HANDLER}
+     */
+    @Override
+    public ExecutableAction create(ElasticInferenceServiceRerankModel model) {
+        var threadPool = serviceComponents.threadPool();
+        var requestManager = new GenericRequestManager<>(
+            threadPool,
+            model,
+            RERANK_HANDLER,
+            // The request metadata is read from the thread context each time this factory builds a request.
+            (rerankInput) -> new ElasticInferenceServiceRerankRequest(
+                rerankInput.getQuery(),
+                rerankInput.getChunks(),
+                rerankInput.getTopN(),
+                model,
+                traceContext,
+                extractRequestMetadataFromThreadContext(threadPool.getThreadContext())
+            ),
+            // NOTE(review): presumably the input type the request factory expects - confirm against GenericRequestManager
+            QueryAndDocsInputs.class
         );
+        var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER));
         return new SenderExecutableAction(sender, requestManager, errorMessage);
     }
 }

+ 3 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java

@@ -9,9 +9,12 @@ package org.elasticsearch.xpack.inference.services.elastic.action;
 
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
 
+/**
+ * Creates an {@link ExecutableAction} for each Elastic Inference Service model type that is executed through an action.
+ */
 public interface ElasticInferenceServiceActionVisitor {
 
+    /** Creates the action for a sparse embeddings model. */
     ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);
 
+    /** Creates the action for a rerank model. */
+    ExecutableAction create(ElasticInferenceServiceRerankModel model);
+
 }

+ 104 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java

@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel;
+import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+/**
+ * Model for the rerank task of the Elastic Inference Service. In addition to the configuration held
+ * by the parent class it stores the URI of the rerank endpoint, which is derived once at construction
+ * time from the configured service URL.
+ */
+public class ElasticInferenceServiceRerankModel extends ElasticInferenceServiceExecutableActionModel {
+
+    /** Path appended to the configured service URL to reach the rerank endpoint. */
+    private static final String RERANK_PATH = "/api/v1/rerank";
+
+    private final URI uri;
+
+    /**
+     * Creates the model from raw configuration maps.
+     * Only {@code serviceSettings} is parsed; {@code taskSettings} and {@code secrets} are not read and
+     * the model is always built with {@link EmptyTaskSettings#INSTANCE} and {@link EmptySecretSettings#INSTANCE}.
+     */
+    public ElasticInferenceServiceRerankModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        Map<String, Object> secrets,
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            ElasticInferenceServiceRerankServiceSettings.fromMap(serviceSettings, context),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            elasticInferenceServiceComponents
+        );
+    }
+
+    /**
+     * Creates the model from already parsed settings and resolves the rerank endpoint URI.
+     *
+     * @throws ElasticsearchStatusException with status {@code BAD_REQUEST} if the endpoint URI is malformed
+     */
+    public ElasticInferenceServiceRerankModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        ElasticInferenceServiceRerankServiceSettings serviceSettings,
+        @Nullable TaskSettings taskSettings,
+        @Nullable SecretSettings secretSettings,
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings,
+            elasticInferenceServiceComponents
+        );
+        this.uri = createUri();
+    }
+
+    /** Dispatches to the visitor's rerank overload; {@code taskSettings} is not used here. */
+    @Override
+    public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map<String, Object> taskSettings) {
+        return visitor.create(this);
+    }
+
+    /** Returns the service settings narrowed to the rerank-specific type. */
+    @Override
+    public ElasticInferenceServiceRerankServiceSettings getServiceSettings() {
+        return (ElasticInferenceServiceRerankServiceSettings) super.getServiceSettings();
+    }
+
+    /** Returns the rerank endpoint URI computed at construction time. */
+    public URI uri() {
+        return uri;
+    }
+
+    /**
+     * Builds the rerank endpoint URI by appending the rerank path to the configured service URL.
+     * A syntax error is rethrown as a {@code BAD_REQUEST} status exception that keeps the original cause.
+     */
+    private URI createUri() throws ElasticsearchStatusException {
+        // TODO, consider transforming the base URL into a URI for better error handling.
+        String endpoint = elasticInferenceServiceComponents().elasticInferenceServiceUrl() + RERANK_PATH;
+        try {
+            return new URI(endpoint);
+        } catch (URISyntaxException e) {
+            String reason = "Failed to create URI for service ["
+                + this.getConfigurations().getService()
+                + "] with taskType ["
+                + this.getTaskType()
+                + "]: "
+                + e.getMessage();
+            throw new ElasticsearchStatusException(reason, RestStatus.BAD_REQUEST, e);
+        }
+    }
+}

+ 126 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+
+/**
+ * Service settings for the rerank task of the Elastic Inference Service: the id of the model to use
+ * and the rate limit settings for requests.
+ */
+public class ElasticInferenceServiceRerankServiceSettings extends FilteredXContentObject
+    implements
+        ServiceSettings,
+        ElasticInferenceServiceRateLimitServiceSettings {
+
+    public static final String NAME = "elastic_rerank_service_settings";
+
+    // Fallback handed to the rate limit parser and used when the constructor receives a null rate limit.
+    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500);
+
+    /**
+     * Parses the settings from a raw service settings map.
+     *
+     * @param map     the raw service settings, expected to contain the model id
+     * @param context the parse context, passed on to the rate limit parser
+     * @return the parsed settings
+     * @throws ValidationException if the extraction helpers recorded any validation error
+     */
+    public static ElasticInferenceServiceRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            ElasticInferenceService.NAME,
+            context
+        );
+
+        // Surface the collected errors as a validation failure; without this check a missing model id
+        // would only fail later in the constructor's null check with a NullPointerException.
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new ElasticInferenceServiceRerankServiceSettings(modelId, rateLimitSettings);
+    }
+
+    private final String modelId;
+
+    private final RateLimitSettings rateLimitSettings;
+
+    /**
+     * @param modelId           the model id, must not be null
+     * @param rateLimitSettings the rate limit settings; the default is used when null
+     */
+    public ElasticInferenceServiceRerankServiceSettings(String modelId, RateLimitSettings rateLimitSettings) {
+        this.modelId = Objects.requireNonNull(modelId);
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    /** Reads the settings from a stream, in the order written by {@link #writeTo(StreamOutput)}. */
+    public ElasticInferenceServiceRerankServiceSettings(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        this.rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    @Override
+    public String modelId() {
+        return modelId;
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitSettings;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_ELASTIC_RERANK_ADDED_8_19;
+    }
+
+    /** Writes the model id field followed by the rate limit settings, without an enclosing object. */
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID, modelId);
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    /** Writes the same fields as the exposed-fields fragment, wrapped in an object. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragmentOfExposedFields(builder, params);
+
+        builder.endObject();
+
+        return builder;
+    }
+
+    /** Writes the model id and then the rate limit settings; the stream constructor reads them in the same order. */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ElasticInferenceServiceRerankServiceSettings that = (ElasticInferenceServiceRerankServiceSettings) o;
+        return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelId, rateLimitSettings);
+    }
+}

+ 122 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java

@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequestEntity;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
+
+/**
+ * Checks the JSON produced by {@link ElasticInferenceServiceRerankRequestEntity} and that its
+ * constructor rejects null query, documents and model id.
+ */
+public class ElasticInferenceServiceRerankRequestEntityTests extends ESTestCase {
+
+    private static final String QUERY = "query";
+    private static final String MODEL_ID = "rerank-model-id";
+
+    public void testToXContent_SingleDocument_NoTopN() throws IOException {
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity(QUERY, List.of("document 1"), MODEL_ID, null);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "documents": ["document 1"]
+            }"""));
+    }
+
+    public void testToXContent_MultipleDocuments_NoTopN() throws IOException {
+        var documents = List.of("document 1", "document 2", "document 3");
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity(QUERY, documents, MODEL_ID, null);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "documents": [
+                    "document 1",
+                    "document 2",
+                    "document 3"
+                ]
+            }
+            """));
+    }
+
+    public void testToXContent_SingleDocument_WithTopN() throws IOException {
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity(QUERY, List.of("document 1"), MODEL_ID, 3);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "top_n": 3,
+                "documents": ["document 1"]
+            }
+            """));
+    }
+
+    public void testToXContent_MultipleDocuments_WithTopN() throws IOException {
+        var documents = List.of("document 1", "document 2", "document 3", "document 4", "document 5");
+        var requestEntity = new ElasticInferenceServiceRerankRequestEntity(QUERY, documents, MODEL_ID, 3);
+        assertThat(toJson(requestEntity), equalToIgnoringWhitespaceInJsonString("""
+            {
+                "query": "query",
+                "model": "rerank-model-id",
+                "top_n": 3,
+                "documents": [
+                    "document 1",
+                    "document 2",
+                    "document 3",
+                    "document 4",
+                    "document 5"
+                ]
+            }
+            """));
+    }
+
+    public void testNullQueryThrowsException() {
+        assertNotNull(
+            expectThrows(
+                NullPointerException.class,
+                () -> new ElasticInferenceServiceRerankRequestEntity(null, List.of("document 1"), "model-id", null)
+            )
+        );
+    }
+
+    public void testNullDocumentsThrowsException() {
+        assertNotNull(
+            expectThrows(
+                NullPointerException.class,
+                () -> new ElasticInferenceServiceRerankRequestEntity("query", null, "model-id", null)
+            )
+        );
+    }
+
+    public void testNullModelIdThrowsException() {
+        assertNotNull(
+            expectThrows(
+                NullPointerException.class,
+                () -> new ElasticInferenceServiceRerankRequestEntity("query", List.of("document 1"), null, null)
+            )
+        );
+    }
+
+    /** Serializes the entity with a JSON builder and returns the resulting string. */
+    private static String toJson(ElasticInferenceServiceRerankRequestEntity entity) throws IOException {
+        XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(jsonBuilder, null);
+        return Strings.toString(jsonBuilder);
+    }
+}

+ 89 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.elastic;
+
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
+import org.elasticsearch.xpack.inference.telemetry.TraceContext;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Checks that {@link ElasticInferenceServiceRerankRequest} copies the trace context into the HTTP
+ * headers and that {@code truncate()} leaves the request body unchanged.
+ */
+public class ElasticInferenceServiceRerankRequestTests extends ESTestCase {
+
+    private static final String URL = "http://eis-gateway.com";
+    private static final String QUERY = "query";
+    private static final String MODEL_ID = "my-model-id";
+    private static final List<String> DOCUMENTS = List.of("document 1", "document 2", "document 3");
+    private static final int TOP_N = 3;
+
+    public void testTraceContextPropagatedThroughHTTPHeaders() {
+        var rerankRequest = createRequest(URL, MODEL_ID, QUERY, DOCUMENTS, TOP_N);
+        var httpRequest = rerankRequest.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var post = (HttpPost) httpRequest.httpRequestBase();
+
+        var expectedTraceParent = rerankRequest.getTraceContext().traceParent();
+        var expectedTraceState = rerankRequest.getTraceContext().traceState();
+
+        assertThat(post.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(expectedTraceParent));
+        assertThat(post.getLastHeader(Task.TRACE_STATE).getValue(), is(expectedTraceState));
+    }
+
+    public void testTruncate_DoesNotTruncate() throws IOException {
+        var truncated = createRequest(URL, MODEL_ID, QUERY, DOCUMENTS, TOP_N).truncate();
+
+        var httpRequest = truncated.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var post = (HttpPost) httpRequest.httpRequestBase();
+        var body = entityAsMap(post.getEntity().getContent());
+        // All four fields must still be present with their original values.
+        assertThat(body, aMapWithSize(4));
+        assertThat(body.get("query"), is(QUERY));
+        assertThat(body.get("model"), is(MODEL_ID));
+        assertThat(body.get("documents"), is(DOCUMENTS));
+        assertThat(body.get("top_n"), is(TOP_N));
+    }
+
+    /** Builds a rerank request for a test model with a random trace context and random request metadata. */
+    private ElasticInferenceServiceRerankRequest createRequest(
+        String url,
+        String modelId,
+        String query,
+        List<String> documents,
+        Integer topN
+    ) {
+        var model = ElasticInferenceServiceRerankModelTests.createModel(url, modelId);
+        var traceContext = new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10));
+
+        return new ElasticInferenceServiceRerankRequest(
+            query,
+            documents,
+            topN,
+            model,
+            traceContext,
+            randomElasticInferenceServiceRequestMetadata()
+        );
+    }
+
+}

+ 148 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceRerankResponseEntityTests.java

@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.elastic;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Checks that {@link ElasticInferenceServiceRerankResponseEntity#fromResponse} turns the JSON
+ * {@code results} array into ranked docs, keeping the order of the response.
+ */
+public class ElasticInferenceServiceRerankResponseEntityTests extends ESTestCase {
+
+    public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
+        var rankedDocs = parseResponse("""
+            {
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.94
+                    }
+                ]
+            }
+            """).getRankedDocs();
+
+        assertThat(rankedDocs, is(List.of(new RankedDocsResults.RankedDoc(0, 0.94F, null))));
+    }
+
+    public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
+        var rankedDocs = parseResponse("""
+            {
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.94
+                    },
+                    {
+                        "index": 1,
+                        "relevance_score": 0.78
+                    },
+                    {
+                        "index": 2,
+                        "relevance_score": 0.65
+                    }
+                ]
+            }
+            """).getRankedDocs();
+
+        var expectedDocs = List.of(
+            new RankedDocsResults.RankedDoc(0, 0.94F, null),
+            new RankedDocsResults.RankedDoc(1, 0.78F, null),
+            new RankedDocsResults.RankedDoc(2, 0.65F, null)
+        );
+        assertThat(rankedDocs, is(expectedDocs));
+    }
+
+    public void testFromResponse_HandlesFloatingPointPrecision() throws IOException {
+        var rankedDocs = parseResponse("""
+            {
+                "results": [
+                    {
+                        "index": 0,
+                        "relevance_score": 0.9432156
+                    }
+                ]
+            }
+            """).getRankedDocs();
+
+        assertThat(rankedDocs, is(List.of(new RankedDocsResults.RankedDoc(0, 0.9432156F, null))));
+    }
+
+    public void testFromResponse_OrderIsPreserved() throws IOException {
+        var rankedDocs = parseResponse("""
+            {
+                "results": [
+                    {
+                        "index": 2,
+                        "relevance_score": 0.94
+                    },
+                    {
+                        "index": 0,
+                        "relevance_score": 0.78
+                    },
+                    {
+                        "index": 1,
+                        "relevance_score": 0.65
+                    }
+                ]
+            }
+            """).getRankedDocs();
+
+        // The docs must come back in response order, not sorted by index.
+        var expectedDocs = List.of(
+            new RankedDocsResults.RankedDoc(2, 0.94F, null),
+            new RankedDocsResults.RankedDoc(0, 0.78F, null),
+            new RankedDocsResults.RankedDoc(1, 0.65F, null)
+        );
+        assertThat(rankedDocs, is(expectedDocs));
+    }
+
+    public void testFromResponse_HandlesEmptyResultsList() throws IOException {
+        var parsedResults = parseResponse("""
+            {
+                "results": []
+            }
+            """);
+
+        assertThat(parsedResults.getRankedDocs(), is(List.of()));
+    }
+
+    /** Wraps the JSON in an {@link HttpResult} around a mocked HTTP response and parses it. */
+    private static RankedDocsResults parseResponse(String responseJson) throws IOException {
+        var httpResult = new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8));
+        return (RankedDocsResults) ElasticInferenceServiceRerankResponseEntity.fromResponse(httpResult);
+    }
+}

+ 131 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
@@ -58,6 +59,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
 import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -149,6 +152,23 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }
 
+    /** Parsing a rerank request config must yield a rerank model carrying the configured model id. */
+    public void testParseRequestConfig_CreatesARerankModel() throws IOException {
+        var expectedModelId = "my-rerank-model-id";
+
+        try (var service = createServiceWithMockSender()) {
+            ActionListener<Model> modelListener = ActionListener.wrap(parsedModel -> {
+                assertThat(parsedModel, instanceOf(ElasticInferenceServiceRerankModel.class));
+                var rerankModel = (ElasticInferenceServiceRerankModel) parsedModel;
+                assertThat(rerankModel.getServiceSettings().modelId(), is(expectedModelId));
+            }, failure -> fail("Model parsing should have succeeded, but failed: " + failure.getMessage()));
+
+            var requestConfig = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, expectedModelId), Map.of(), Map.of());
+            service.parseRequestConfig("id", TaskType.RERANK, requestConfig, modelListener);
+        }
+    }
+
     public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
         try (var service = createServiceWithMockSender()) {
             var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of());
@@ -367,6 +387,39 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         verifyNoMoreInteractions(sender);
     }
 
+    /**
+     * Passing {@code return_documents} to a rerank inference call must be rejected with a validation
+     * error, because the service does not support that option.
+     */
+    public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
+        // The service builds its own mock sender, so no separate sender/factory mocks are needed here.
+        try (var service = createServiceWithMockSender()) {
+            var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), "my-rerank-model-id");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+            var thrownException = expectThrows(
+                ValidationException.class,
+                () -> service.infer(
+                    model,
+                    "search query",
+                    // return_documents: the unsupported option under test
+                    Boolean.TRUE,
+                    10,
+                    List.of("doc1", "doc2", "doc3"),
+                    false,
+                    new HashMap<>(),
+                    InputType.SEARCH,
+                    InferenceAction.Request.DEFAULT_TIMEOUT,
+                    listener
+                )
+            );
+
+            assertThat(
+                thrownException.getMessage(),
+                is("Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this service;")
+            );
+        }
+    }
+
     public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
         var sender = mock(Sender.class);
 
@@ -395,7 +448,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                 thrownException.getMessage(),
                 is(
                     "Inference entity [model_id] does not support task type [text_embedding] "
-                        + "for inference, the task type must be one of [sparse_embedding]."
+                        + "for inference, the task type must be one of [sparse_embedding, rerank]."
                 )
             );
 
@@ -436,7 +489,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                 thrownException.getMessage(),
                 is(
                     "Inference entity [model_id] does not support task type [chat_completion] "
-                        + "for inference, the task type must be one of [sparse_embedding]. "
+                        + "for inference, the task type must be one of [sparse_embedding, rerank]. "
                         + "The task type for the inference entity is chat_completion, "
                         + "please use the _inference/chat_completion/model_id/_stream URL."
                 )
@@ -504,6 +557,76 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }
 
+    /**
+     * Runs a rerank inference against the mock web server and checks both the parsed results and the
+     * HTTP request that was sent (no query string, JSON content type, expected body).
+     */
+    @SuppressWarnings("unchecked")
+    public void testRerank_SendsRerankRequest() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        var elasticInferenceServiceURL = getUrl(webServer);
+
+        try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
+            var modelId = "my-model-id";
+            var topN = 2;
+            var query = "search query";
+            var documents = List.of("doc1", "doc2", "doc3");
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
+                {
+                    "results": [
+                        {"index": 0, "relevance_score": 0.95},
+                        {"index": 1, "relevance_score": 0.85},
+                        {"index": 2, "relevance_score": 0.75}
+                    ]
+                }
+                """));
+
+            var model = ElasticInferenceServiceRerankModelTests.createModel(elasticInferenceServiceURL, modelId);
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+            service.infer(
+                model,
+                query,
+                null,
+                topN,
+                documents,
+                false,
+                new HashMap<>(),
+                InputType.SEARCH,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var rerankResults = (List<Map<String, Object>>) listener.actionGet(TIMEOUT).asMap().get("rerank");
+            assertThat(rerankResults.size(), Matchers.is(3));
+
+            // The mocked response lists indices 0, 1, 2 in order; the parsed docs must match position by position.
+            for (int position = 0; position < 3; position++) {
+                var rankedDoc = (Map<String, Object>) rerankResults.get(position).get("ranked_doc");
+                assertThat(rankedDoc.get("index"), equalTo(position));
+            }
+
+            // Verify the outgoing HTTP request
+            var sentRequest = webServer.requests().get(0);
+            assertNull(sentRequest.getUri().getQuery());
+            assertThat(sentRequest.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType()));
+
+            // Verify the outgoing request body
+            Map<String, Object> actualBody = entityAsMap(sentRequest.getBody());
+            Map<String, Object> expectedBody = Map.of("query", query, "model", modelId, "top_n", topN, "documents", documents);
+            assertThat(actualBody, is(expectedBody));
+        }
+    }
+
     public void testInfer_PropagatesProductUseCaseHeader() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         var elasticInferenceServiceURL = getUrl(webServer);
@@ -850,7 +973,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "int",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -859,7 +982,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",
@@ -905,7 +1028,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "int",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -914,7 +1037,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",
@@ -974,7 +1097,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "int",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -983,7 +1106,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                               "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",

+ 75 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java

@@ -20,12 +20,15 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 import org.junit.After;
 import org.junit.Before;
@@ -181,6 +184,78 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
         }
     }
 
+    /**
+     * Executes a rerank action against the mock web server and verifies both the result built
+     * from the canned two-entry response and the JSON body of the single outgoing HTTP request.
+     */
+    @SuppressWarnings("unchecked")
+    public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOException {
+        var factory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(factory)) {
+            sender.start();
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
+                {
+                    "results": [
+                        {
+                            "index": 0,
+                            "relevance_score": 0.94
+                        },
+                        {
+                            "index": 1,
+                            "relevance_score": 0.21
+                        }
+                    ]
+                }
+                """));
+
+            var query = "query";
+            var documents = List.of("document 1", "document 2", "document 3");
+            var topN = 3;
+            var modelId = "my-model-id";
+
+            var rerankModel = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), modelId);
+            var creator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
+            var action = creator.create(rerankModel);
+
+            PlainActionFuture<InferenceServiceResults> future = new PlainActionFuture<>();
+            var inputs = new QueryAndDocsInputs(query, documents, null, topN, false);
+            action.execute(inputs, InferenceAction.Request.DEFAULT_TIMEOUT, future);
+            var resultMap = future.actionGet(TIMEOUT).asMap();
+
+            assertThat(
+                resultMap,
+                equalTo(
+                    RankedDocsResultsTests.buildExpectationRerank(
+                        List.of(
+                            new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", 0.94f)),
+                            new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 0.21f))
+                        )
+                    )
+                )
+            );
+
+            // Verify the outgoing HTTP request: exactly one, without a query string, sent as JSON
+            assertThat(webServer.requests(), hasSize(1));
+            var request = webServer.requests().get(0);
+            assertNull(request.getUri().getQuery());
+            assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+
+            // Verify the outgoing request body: four fields, documents in their original order
+            var requestMap = entityAsMap(request.getBody());
+            assertThat(requestMap.size(), is(4));
+            assertThat(requestMap.get("documents"), instanceOf(List.class));
+            List<String> sentDocuments = (List<String>) requestMap.get("documents");
+            for (int i = 0; i < documents.size(); i++) {
+                assertThat(sentDocuments.get(i), equalTo(documents.get(i)));
+            }
+            assertThat(requestMap.get("top_n"), equalTo(topN));
+            assertThat(requestMap.get("query"), equalTo(query));
+            assertThat(requestMap.get("model"), equalTo(modelId));
+        }
+    }
+
     public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 

+ 30 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java

@@ -0,0 +1,30 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
+
+public class ElasticInferenceServiceRerankModelTests extends ESTestCase {
+
+    public static ElasticInferenceServiceRerankModel createModel(String url, String modelId) {
+        return new ElasticInferenceServiceRerankModel(
+            "id",
+            TaskType.RERANK,
+            "service",
+            new ElasticInferenceServiceRerankServiceSettings(modelId, null),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            ElasticInferenceServiceComponents.of(url)
+        );
+    }
+
+}

+ 76 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.rerank;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/** Tests {@code fromMap} parsing and XContent output; the overrides feed the wire-serializing base test case. */
+public class ElasticInferenceServiceRerankServiceSettingsTests extends AbstractWireSerializingTestCase<
+    ElasticInferenceServiceRerankServiceSettings> {
+
+    @Override
+    protected Writeable.Reader<ElasticInferenceServiceRerankServiceSettings> instanceReader() {
+        return ElasticInferenceServiceRerankServiceSettings::new;
+    }
+
+    @Override
+    protected ElasticInferenceServiceRerankServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ElasticInferenceServiceRerankServiceSettings mutateInstance(ElasticInferenceServiceRerankServiceSettings instance)
+        throws IOException {
+        return randomValueOtherThan(instance, this::createTestInstance);
+    }
+
+    public void testFromMap() {
+        var modelId = "my-model-id";
+        Map<String, Object> settingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId));
+
+        var parsed = ElasticInferenceServiceRerankServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST);
+
+        // The map carries only a model id; the expected instance is built with a null rate limit argument
+        assertThat(parsed, is(new ElasticInferenceServiceRerankServiceSettings(modelId, null)));
+    }
+
+    public void testToXContent_WritesAllFields() throws IOException {
+        var modelId = ".rerank-v1";
+        var rateLimit = new RateLimitSettings(100L);
+        var settings = new ElasticInferenceServiceRerankServiceSettings(modelId, rateLimit);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        settings.toXContent(builder, null);
+        String actualJson = Strings.toString(builder);
+
+        String expectedJson = Strings.format("""
+            {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimit.requestsPerTimeUnit());
+        assertThat(actualJson, is(expectedJson));
+    }
+
+    /**
+     * Creates settings with a model id picked at random from the two literals below and a null rate limit argument.
+     */
+    public static ElasticInferenceServiceRerankServiceSettings createRandom() {
+        return new ElasticInferenceServiceRerankServiceSettings(randomFrom(".rerank-v1", ".rerank-v2"), null);
+    }
+}