Browse Source

Adding default endpoint for Elastic Rerank (#117939) (#118153)

* Adding default endpoint for Elastic Rerank

* CustomElandRerankTaskSettings -> RerankTaskSettings

* Update docs/changelog/117939.yaml
Ying Mao 10 months ago
parent
commit
59da06bc6d

+ 5 - 0
docs/changelog/117939.yaml

@@ -0,0 +1,5 @@
+pr: 117939
+summary: Adding default endpoint for Elastic Rerank
+area: Machine Learning
+type: enhancement
+issues: []

+ 40 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

@@ -49,6 +49,9 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
 
         var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
         assertDefaultE5Config(e5Model);
+
+        var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
+        assertDefaultRerankConfig(rerankModel);
     }
 
     @SuppressWarnings("unchecked")
@@ -117,6 +120,42 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
         assertDefaultChunkingSettings(modelConfig);
     }
 
+    @SuppressWarnings("unchecked")
+    public void testInferDeploysDefaultRerank() throws IOException {
+        var model = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
+        assertDefaultRerankConfig(model);
+
+        var inputs = List.of("Hello World", "Goodnight moon");
+        var query = "but why";
+        var queryParams = Map.of("timeout", "120s");
+        var results = infer(ElasticsearchInternalService.DEFAULT_RERANK_ID, TaskType.RERANK, inputs, query, queryParams);
+        var embeddings = (List<Map<String, Object>>) results.get("rerank");
+        assertThat(results.toString(), embeddings, hasSize(2));
+    }
+
+    @SuppressWarnings("unchecked")
+    private static void assertDefaultRerankConfig(Map<String, Object> modelConfig) {
+        assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_RERANK_ID, modelConfig.get("inference_id"));
+        assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
+        assertEquals(modelConfig.toString(), TaskType.RERANK.toString(), modelConfig.get("task_type"));
+
+        var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
+        assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(".rerank-v1"));
+        assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
+
+        var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
+        assertThat(
+            modelConfig.toString(),
+            adaptiveAllocations,
+            Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
+        );
+
+        var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
+        assertNull(chunkingSettings);
+        var taskSettings = (Map<String, Object>) modelConfig.get("task_settings");
+        assertThat(modelConfig.toString(), taskSettings, Matchers.is(Map.of("return_documents", true)));
+    }
+
     @SuppressWarnings("unchecked")
     private static void assertDefaultChunkingSettings(Map<String, Object> modelConfig) {
         var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
@@ -151,6 +190,7 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
             var request = createInferenceRequest(
                 Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
                 inputs,
+                null,
                 queryParams
             );
             client().performRequestAsync(request, listener);

+ 37 - 10
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -333,7 +333,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
 
     protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
         var endpoint = Strings.format("_inference/%s", modelId);
-        return inferInternal(endpoint, input, Map.of());
+        return inferInternal(endpoint, input, null, Map.of());
     }
 
     protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
@@ -344,7 +344,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
     private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
         var responseConsumer = new AsyncInferenceResponseConsumer();
         var request = new Request("POST", endpoint);
-        request.setJsonEntity(jsonBody(input));
+        request.setJsonEntity(jsonBody(input, null));
         request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
         var latch = new CountDownLatch(1);
         client().performRequestAsync(request, new ResponseListener() {
@@ -364,33 +364,60 @@ public class InferenceBaseRestTest extends ESRestTestCase {
 
     protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
         var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
-        return inferInternal(endpoint, input, Map.of());
+        return inferInternal(endpoint, input, null, Map.of());
     }
 
     protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
         throws IOException {
         var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
-        return inferInternal(endpoint, input, queryParameters);
+        return inferInternal(endpoint, input, null, queryParameters);
     }
 
-    protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
+    protected Map<String, Object> infer(
+        String modelId,
+        TaskType taskType,
+        List<String> input,
+        String query,
+        Map<String, String> queryParameters
+    ) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
+        return inferInternal(endpoint, input, query, queryParameters);
+    }
+
+    protected Request createInferenceRequest(
+        String endpoint,
+        List<String> input,
+        @Nullable String query,
+        Map<String, String> queryParameters
+    ) {
         var request = new Request("POST", endpoint);
-        request.setJsonEntity(jsonBody(input));
+        request.setJsonEntity(jsonBody(input, query));
         if (queryParameters.isEmpty() == false) {
             request.addParameters(queryParameters);
         }
         return request;
     }
 
-    private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
-        var request = createInferenceRequest(endpoint, input, queryParameters);
+    private Map<String, Object> inferInternal(
+        String endpoint,
+        List<String> input,
+        @Nullable String query,
+        Map<String, String> queryParameters
+    ) throws IOException {
+        var request = createInferenceRequest(endpoint, input, query, queryParameters);
         var response = client().performRequest(request);
         assertOkOrCreated(response);
         return entityAsMap(response);
     }
 
-    private String jsonBody(List<String> input) {
-        var bodyBuilder = new StringBuilder("{\"input\": [");
+    private String jsonBody(List<String> input, @Nullable String query) {
+        final StringBuilder bodyBuilder = new StringBuilder("{");
+
+        if (query != null) {
+            bodyBuilder.append("\"query\":\"").append(query).append("\",");
+        }
+
+        bodyBuilder.append("\"input\": [");
         for (var in : input) {
             bodyBuilder.append('"').append(in).append('"').append(',');
         }

+ 2 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -41,7 +41,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         }
 
         var getAllModels = getAllModels();
-        int numModels = 11;
+        int numModels = 12;
         assertThat(getAllModels, hasSize(numModels));
 
         var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -328,7 +328,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     }
 
     public void testGetZeroModels() throws IOException {
-        var models = getModels("_all", TaskType.RERANK);
+        var models = getModels("_all", TaskType.COMPLETION);
         assertThat(models, empty());
     }
 }

+ 2 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -62,12 +62,12 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTask
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
-import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
+import org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings;
 import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
@@ -510,9 +510,7 @@ public class InferenceNamedWriteablesProvider {
                 CustomElandInternalTextEmbeddingServiceSettings::new
             )
         );
-        namedWriteables.add(
-            new NamedWriteableRegistry.Entry(TaskSettings.class, CustomElandRerankTaskSettings.NAME, CustomElandRerankTaskSettings::new)
-        );
+        namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, RerankTaskSettings.NAME, RerankTaskSettings::new));
     }
 
     private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java

@@ -17,7 +17,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
-import static org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings.RETURN_DOCUMENTS;
+import static org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings.RETURN_DOCUMENTS;
 
 public class CustomElandRerankModel extends CustomElandModel {
 
@@ -26,7 +26,7 @@ public class CustomElandRerankModel extends CustomElandModel {
         TaskType taskType,
         String service,
         CustomElandInternalServiceSettings serviceSettings,
-        CustomElandRerankTaskSettings taskSettings
+        RerankTaskSettings taskSettings
     ) {
         super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
     }

+ 2 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
 
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
@@ -22,9 +21,9 @@ public class ElasticRerankerModel extends ElasticsearchInternalModel {
         TaskType taskType,
         String service,
         ElasticRerankerServiceSettings serviceSettings,
-        ChunkingSettings chunkingSettings
+        RerankTaskSettings taskSettings
     ) {
-        super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
+        super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
     }
 
     @Override

+ 37 - 18
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -101,6 +101,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
     public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
     public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
+    public static final String DEFAULT_RERANK_ID = ".rerank-v1-elasticsearch";
 
     private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
         TaskType.RERANK,
@@ -225,7 +226,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                     )
                 );
             } else if (RERANKER_ID.equals(modelId)) {
-                rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener);
+                rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, taskSettingsMap, modelListener);
             } else {
                 customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
             }
@@ -308,7 +309,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                 taskType,
                 NAME,
                 elandServiceSettings(serviceSettings, context),
-                CustomElandRerankTaskSettings.fromMap(taskSettings)
+                RerankTaskSettings.fromMap(taskSettings)
             );
             default -> throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
         };
@@ -331,7 +332,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         TaskType taskType,
         Map<String, Object> config,
         Map<String, Object> serviceSettingsMap,
-        ChunkingSettings chunkingSettings,
+        Map<String, Object> taskSettingsMap,
         ActionListener<Model> modelListener
     ) {
 
@@ -346,7 +347,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                 taskType,
                 NAME,
                 new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
-                chunkingSettings
+                RerankTaskSettings.fromMap(taskSettingsMap)
             )
         );
     }
@@ -512,6 +513,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                 ElserMlNodeTaskSettings.DEFAULT,
                 chunkingSettings
             );
+        } else if (modelId.equals(RERANKER_ID)) {
+            return new ElasticRerankerModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)),
+                RerankTaskSettings.fromMap(taskSettingsMap)
+            );
         } else {
             return createCustomElandModel(
                 inferenceEntityId,
@@ -653,21 +662,23 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     ) {
         var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
 
-        var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings();
-        var requestSettings = CustomElandRerankTaskSettings.fromMap(requestTaskSettings);
-        Boolean returnDocs = CustomElandRerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
+        var returnDocs = Boolean.TRUE;
+        if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
+            var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
+            returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
+        }
 
         Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
 
-        client.execute(
-            InferModelAction.INSTANCE,
-            request,
-            listener.delegateFailureAndWrap(
-                (l, inferenceResult) -> l.onResponse(
-                    textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)
-                )
-            )
+        ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
+            (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
+        );
+
+        var maybeDeployListener = mlResultsListener.delegateResponse(
+            (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
         );
+
+        client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
     }
 
     public void chunkedInfer(
@@ -811,7 +822,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     public List<DefaultConfigId> defaultConfigIds() {
         return List.of(
             new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
-            new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
+            new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this),
+            new DefaultConfigId(DEFAULT_RERANK_ID, TaskType.RERANK, this)
         );
     }
 
@@ -903,12 +915,19 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
             ),
             ChunkingSettingsBuilder.DEFAULT_SETTINGS
         );
-        return List.of(defaultElser, defaultE5);
+        var defaultRerank = new ElasticRerankerModel(
+            DEFAULT_RERANK_ID,
+            TaskType.RERANK,
+            NAME,
+            new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)),
+            RerankTaskSettings.DEFAULT_SETTINGS
+        );
+        return List.of(defaultElser, defaultE5, defaultRerank);
     }
 
     @Override
     boolean isDefaultId(String inferenceId) {
-        return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
+        return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId) || DEFAULT_RERANK_ID.equals(inferenceId);
     }
 
     static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(

+ 11 - 14
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java → x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettings.java

@@ -26,14 +26,14 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOpt
 /**
  * Defines the task settings for internal rerank service.
  */
-public class CustomElandRerankTaskSettings implements TaskSettings {
+public class RerankTaskSettings implements TaskSettings {
 
     public static final String NAME = "custom_eland_rerank_task_settings";
     public static final String RETURN_DOCUMENTS = "return_documents";
 
-    static final CustomElandRerankTaskSettings DEFAULT_SETTINGS = new CustomElandRerankTaskSettings(Boolean.TRUE);
+    static final RerankTaskSettings DEFAULT_SETTINGS = new RerankTaskSettings(Boolean.TRUE);
 
-    public static CustomElandRerankTaskSettings defaultsFromMap(Map<String, Object> map) {
+    public static RerankTaskSettings defaultsFromMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
 
         if (map == null || map.isEmpty()) {
@@ -49,7 +49,7 @@ public class CustomElandRerankTaskSettings implements TaskSettings {
             returnDocuments = true;
         }
 
-        return new CustomElandRerankTaskSettings(returnDocuments);
+        return new RerankTaskSettings(returnDocuments);
     }
 
     /**
@@ -57,13 +57,13 @@ public class CustomElandRerankTaskSettings implements TaskSettings {
      * @param map source map
      * @return Task settings
      */
-    public static CustomElandRerankTaskSettings fromMap(Map<String, Object> map) {
+    public static RerankTaskSettings fromMap(Map<String, Object> map) {
         if (map == null || map.isEmpty()) {
             return DEFAULT_SETTINGS;
         }
 
         Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, new ValidationException());
-        return new CustomElandRerankTaskSettings(returnDocuments);
+        return new RerankTaskSettings(returnDocuments);
     }
 
     /**
@@ -74,20 +74,17 @@ public class CustomElandRerankTaskSettings implements TaskSettings {
      * @param requestTaskSettings the settings passed in within the task_settings field of the request
      * @return Either {@code originalSettings} or {@code requestTaskSettings}
      */
-    public static CustomElandRerankTaskSettings of(
-        CustomElandRerankTaskSettings originalSettings,
-        CustomElandRerankTaskSettings requestTaskSettings
-    ) {
+    public static RerankTaskSettings of(RerankTaskSettings originalSettings, RerankTaskSettings requestTaskSettings) {
         return requestTaskSettings.returnDocuments() != null ? requestTaskSettings : originalSettings;
     }
 
     private final Boolean returnDocuments;
 
-    public CustomElandRerankTaskSettings(StreamInput in) throws IOException {
+    public RerankTaskSettings(StreamInput in) throws IOException {
         this(in.readOptionalBoolean());
     }
 
-    public CustomElandRerankTaskSettings(@Nullable Boolean doReturnDocuments) {
+    public RerankTaskSettings(@Nullable Boolean doReturnDocuments) {
         if (doReturnDocuments == null) {
             this.returnDocuments = true;
         } else {
@@ -133,7 +130,7 @@ public class CustomElandRerankTaskSettings implements TaskSettings {
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
-        CustomElandRerankTaskSettings that = (CustomElandRerankTaskSettings) o;
+        RerankTaskSettings that = (RerankTaskSettings) o;
         return Objects.equals(returnDocuments, that.returnDocuments);
     }
 
@@ -144,7 +141,7 @@ public class CustomElandRerankTaskSettings implements TaskSettings {
 
     @Override
     public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
-        CustomElandRerankTaskSettings updatedSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>(newSettings));
+        RerankTaskSettings updatedSettings = RerankTaskSettings.fromMap(new HashMap<>(newSettings));
         return of(this, updatedSettings);
     }
 }

+ 23 - 19
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -534,16 +534,13 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                 )
             );
             var returnDocs = randomBoolean();
-            settings.put(
-                ModelConfigurations.TASK_SETTINGS,
-                new HashMap<>(Map.of(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, returnDocs))
-            );
+            settings.put(ModelConfigurations.TASK_SETTINGS, new HashMap<>(Map.of(RerankTaskSettings.RETURN_DOCUMENTS, returnDocs)));
 
             ActionListener<Model> modelListener = ActionListener.<Model>wrap(model -> {
                 assertThat(model, instanceOf(CustomElandRerankModel.class));
-                assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class));
+                assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class));
                 assertThat(model.getServiceSettings(), instanceOf(CustomElandInternalServiceSettings.class));
-                assertEquals(returnDocs, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments());
+                assertEquals(returnDocs, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments());
             }, e -> { fail("Model parsing failed " + e.getMessage()); });
 
             service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, modelListener);
@@ -583,9 +580,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 
             ActionListener<Model> modelListener = ActionListener.<Model>wrap(model -> {
                 assertThat(model, instanceOf(CustomElandRerankModel.class));
-                assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class));
+                assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class));
                 assertThat(model.getServiceSettings(), instanceOf(CustomElandInternalServiceSettings.class));
-                assertEquals(Boolean.TRUE, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments());
+                assertEquals(Boolean.TRUE, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments());
             }, e -> { fail("Model parsing failed " + e.getMessage()); });
 
             service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, modelListener);
@@ -1249,14 +1246,11 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             );
             settings.put(ElasticsearchInternalServiceSettings.MODEL_ID, "foo");
             var returnDocs = randomBoolean();
-            settings.put(
-                ModelConfigurations.TASK_SETTINGS,
-                new HashMap<>(Map.of(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, returnDocs))
-            );
+            settings.put(ModelConfigurations.TASK_SETTINGS, new HashMap<>(Map.of(RerankTaskSettings.RETURN_DOCUMENTS, returnDocs)));
 
             var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.RERANK, settings);
-            assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class));
-            assertEquals(returnDocs, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments());
+            assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class));
+            assertEquals(returnDocs, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments());
         }
 
         // without task settings
@@ -1279,8 +1273,8 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             settings.put(ElasticsearchInternalServiceSettings.MODEL_ID, "foo");
 
             var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.RERANK, settings);
-            assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class));
-            assertTrue(((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments());
+            assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class));
+            assertTrue(((RerankTaskSettings) model.getTaskSettings()).returnDocuments());
         }
     }
 
@@ -1335,7 +1329,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                 taskType,
                 ElasticsearchInternalService.NAME,
                 new CustomElandInternalServiceSettings(1, 4, "custom-model", null),
-                CustomElandRerankTaskSettings.DEFAULT_SETTINGS
+                RerankTaskSettings.DEFAULT_SETTINGS
             );
         } else if (taskType == TaskType.TEXT_EMBEDDING) {
             var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model", null);
@@ -1528,20 +1522,30 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             )
         );
 
-        var e = expectThrows(
+        var e1 = expectThrows(
             ElasticsearchStatusException.class,
             () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings(
                 TaskType.COMPLETION,
                 new ElasticsearchInternalServiceSettings(1, 1, "foo", null)
             )
         );
-        assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]"));
+        assertThat(e1.getMessage(), containsString("Chunking is not supported for task type [completion]"));
+
+        var e2 = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings(
+                TaskType.RERANK,
+                new ElasticsearchInternalServiceSettings(1, 1, "foo", null)
+            )
+        );
+        assertThat(e2.getMessage(), containsString("Chunking is not supported for task type [rerank]"));
     }
 
     public void testIsDefaultId() {
         var service = createService(mock(Client.class));
         assertTrue(service.isDefaultId(".elser-2-elasticsearch"));
         assertTrue(service.isDefaultId(".multilingual-e5-small-elasticsearch"));
+        assertTrue(service.isDefaultId(".rerank-v1-elasticsearch"));
         assertFalse(service.isDefaultId("foo"));
     }
 

+ 24 - 24
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java → x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettingsTests.java

@@ -22,7 +22,7 @@ import java.util.Map;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.sameInstance;
 
-public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingTestCase<CustomElandRerankTaskSettings> {
+public class RerankTaskSettingsTests extends AbstractWireSerializingTestCase<RerankTaskSettings> {
 
     public void testIsEmpty() {
         var randomSettings = createRandom();
@@ -35,9 +35,9 @@ public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingT
         var newSettings = createRandom();
         Map<String, Object> newSettingsMap = new HashMap<>();
         if (newSettings.returnDocuments() != null) {
-            newSettingsMap.put(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, newSettings.returnDocuments());
+            newSettingsMap.put(RerankTaskSettings.RETURN_DOCUMENTS, newSettings.returnDocuments());
         }
-        CustomElandRerankTaskSettings updatedSettings = (CustomElandRerankTaskSettings) initialSettings.updatedTaskSettings(
+        RerankTaskSettings updatedSettings = (RerankTaskSettings) initialSettings.updatedTaskSettings(
             Collections.unmodifiableMap(newSettingsMap)
         );
         if (newSettings.returnDocuments() == null) {
@@ -48,37 +48,37 @@ public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingT
     }
 
     public void testDefaultsFromMap_MapIsNull_ReturnsDefaultSettings() {
-        var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(null);
+        var rerankTaskSettings = RerankTaskSettings.defaultsFromMap(null);
 
-        assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS));
+        assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS));
     }
 
     public void testDefaultsFromMap_MapIsEmpty_ReturnsDefaultSettings() {
-        var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>());
+        var rerankTaskSettings = RerankTaskSettings.defaultsFromMap(new HashMap<>());
 
-        assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS));
+        assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS));
     }
 
     public void testDefaultsFromMap_ExtractedReturnDocumentsNull_SetsReturnDocumentToTrue() {
-        var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>());
+        var rerankTaskSettings = RerankTaskSettings.defaultsFromMap(new HashMap<>());
 
-        assertThat(customElandRerankTaskSettings.returnDocuments(), is(Boolean.TRUE));
+        assertThat(rerankTaskSettings.returnDocuments(), is(Boolean.TRUE));
     }
 
     public void testFromMap_MapIsNull_ReturnsDefaultSettings() {
-        var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(null);
+        var rerankTaskSettings = RerankTaskSettings.fromMap(null);
 
-        assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS));
+        assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS));
     }
 
     public void testFromMap_MapIsEmpty_ReturnsDefaultSettings() {
-        var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>());
+        var rerankTaskSettings = RerankTaskSettings.fromMap(new HashMap<>());
 
-        assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS));
+        assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS));
     }
 
     public void testToXContent_WritesAllValues() throws IOException {
-        var serviceSettings = new CustomElandRerankTaskSettings(Boolean.TRUE);
+        var serviceSettings = new RerankTaskSettings(Boolean.TRUE);
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         serviceSettings.toXContent(builder, null);
@@ -89,30 +89,30 @@ public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingT
     }
 
     public void testOf_PrefersNonNullRequestTaskSettings() {
-        var originalSettings = new CustomElandRerankTaskSettings(Boolean.FALSE);
-        var requestTaskSettings = new CustomElandRerankTaskSettings(Boolean.TRUE);
+        var originalSettings = new RerankTaskSettings(Boolean.FALSE);
+        var requestTaskSettings = new RerankTaskSettings(Boolean.TRUE);
 
-        var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings);
+        var taskSettings = RerankTaskSettings.of(originalSettings, requestTaskSettings);
 
         assertThat(taskSettings, sameInstance(requestTaskSettings));
     }
 
-    private static CustomElandRerankTaskSettings createRandom() {
-        return new CustomElandRerankTaskSettings(randomOptionalBoolean());
+    private static RerankTaskSettings createRandom() {
+        return new RerankTaskSettings(randomOptionalBoolean());
     }
 
     @Override
-    protected Writeable.Reader<CustomElandRerankTaskSettings> instanceReader() {
-        return CustomElandRerankTaskSettings::new;
+    protected Writeable.Reader<RerankTaskSettings> instanceReader() {
+        return RerankTaskSettings::new;
     }
 
     @Override
-    protected CustomElandRerankTaskSettings createTestInstance() {
+    protected RerankTaskSettings createTestInstance() {
         return createRandom();
     }
 
     @Override
-    protected CustomElandRerankTaskSettings mutateInstance(CustomElandRerankTaskSettings instance) throws IOException {
-        return randomValueOtherThan(instance, CustomElandRerankTaskSettingsTests::createRandom);
+    protected RerankTaskSettings mutateInstance(RerankTaskSettings instance) throws IOException {
+        return randomValueOtherThan(instance, RerankTaskSettingsTests::createRandom);
     }
 }