Browse Source

[Inference API] Add special case to inference API (#116962) (#117035)

* Add reranker special case to inference API

* Update docs/changelog/116962.yaml

* Update 116962.yaml

* spotless

* improvements from review

* Fix typo
Max Hniebergall 11 months ago
parent
commit
68f3a64ea1

+ 5 - 0
docs/changelog/116962.yaml

@@ -0,0 +1,5 @@
+pr: 116962
+summary: "Add special case for elastic reranker in inference API"
+area: Machine Learning
+type: enhancement
+issues: []

+ 8 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -63,6 +63,7 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServic
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
@@ -415,7 +416,13 @@ public class InferenceNamedWriteablesProvider {
                 MultilingualE5SmallInternalServiceSettings::new
             )
         );
-
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                ElasticRerankerServiceSettings.NAME,
+                ElasticRerankerServiceSettings::new
+            )
+        );
     }
 
     private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

@@ -156,6 +156,8 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
             putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
         } else if (model instanceof ElserInternalModel elserModel) {
             putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
+        } else if (model instanceof ElasticRerankerModel elasticRerankerModel) {
+            putBuiltInModel(elasticRerankerModel.getServiceSettings().modelId(), listener);
         } else if (model instanceof CustomElandModel) {
             logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
             listener.onResponse(Boolean.TRUE);

+ 60 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java

@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elasticsearch;
+
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+public class ElasticRerankerModel extends ElasticsearchInternalModel {
+
+    public ElasticRerankerModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        ElasticRerankerServiceSettings serviceSettings,
+        ChunkingSettings chunkingSettings
+    ) {
+        super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
+    }
+
+    @Override
+    public ElasticRerankerServiceSettings getServiceSettings() {
+        return (ElasticRerankerServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
+        Model model,
+        ActionListener<Boolean> listener
+    ) {
+
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
+                listener.onResponse(Boolean.TRUE);
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                    listener.onFailure(
+                        new ResourceNotFoundException("Could not start the Elastic Reranker Endpoint due to [{}]", e, e.getMessage())
+                    );
+                    return;
+                }
+                listener.onFailure(e);
+            }
+        };
+    }
+
+}

+ 62 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elasticsearch;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
+
+    public static final String NAME = "elastic_reranker_service_settings";
+
+    public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
+        super(other);
+    }
+
+    public ElasticRerankerServiceSettings(
+        Integer numAllocations,
+        int numThreads,
+        String modelId,
+        AdaptiveAllocationsSettings adaptiveAllocationsSettings
+    ) {
+        super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
+    }
+
+    public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    /**
+     * Parse the ElasticRerankerServiceSettings from map and validate the setting values.
+     *
+     * If required setting are missing or the values are invalid an
+     * {@link ValidationException} is thrown.
+     *
+     * @param map Source map containing the config
+     * @return The builder
+     */
+    public static Builder fromRequestMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+        var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return baseSettings;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ElasticRerankerServiceSettings.NAME;
+    }
+}

+ 29 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -97,6 +97,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
     );
 
+    public static final String RERANKER_ID = ".rerank-v1";
+
     public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
     public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
     public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
@@ -223,6 +225,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                         )
                     )
                 );
+            } else if (RERANKER_ID.equals(modelId)) {
+                rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener);
             } else {
                 customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
             }
@@ -323,6 +327,31 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         };
     }
 
+    private void rerankerCase(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> config,
+        Map<String, Object> serviceSettingsMap,
+        ChunkingSettings chunkingSettings,
+        ActionListener<Model> modelListener
+    ) {
+
+        var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap);
+
+        throwIfNotEmptyMap(config, name());
+        throwIfNotEmptyMap(serviceSettingsMap, name());
+
+        modelListener.onResponse(
+            new ElasticRerankerModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
+                chunkingSettings
+            )
+        );
+    }
+
     private void e5Case(
         String inferenceEntityId,
         TaskType taskType,