浏览代码

Add a default inference endpoint for the Elastic Inference Service rerank task (#129681)

* add Elastic Inference Service rerank default inference endpoint

* [CI] Auto commit changes from spotless

* fix integration tests

* update mock Elastic Inference Service authorization response to include the rerank-v1 model

* fix rerank service test

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Brendan Jugan 3 月之前
父节点
当前提交
cef717c087

+ 2 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

@@ -33,7 +33,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
         var allModels = getAllModels();
         var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
 
-        assertThat(allModels, hasSize(5));
+        assertThat(allModels, hasSize(6));
         assertThat(chatCompletionModels, hasSize(1));
 
         for (var model : chatCompletionModels) {
@@ -42,6 +42,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
 
         assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
         assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
+        assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
     }
 
     private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

+ 3 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

@@ -111,7 +111,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
 
     public void testGetServicesWithRerankTaskType() throws IOException {
         List<Object> services = getServices(TaskType.RERANK);
-        assertThat(services.size(), equalTo(9));
+        assertThat(services.size(), equalTo(10));
 
         var providers = providers(services);
 
@@ -127,7 +127,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "jinaai",
                     "test_reranking_service",
                     "voyageai",
-                    "hugging_face"
+                    "hugging_face",
+                    "elastic"
                 ).toArray()
             )
         );

+ 4 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

@@ -41,6 +41,10 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule
                     {
                       "model_name": "elser-v2",
                       "task_types": ["embed/text/sparse"]
+                    },
+                    {
+                      "model_name": "rerank-v1",
+                      "task_types": ["rerank/text/text-similarity"]
                     }
                 ]
             }

+ 24 - 2
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

@@ -197,6 +197,10 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                         {
                           "model_name": "elser-v2",
                           "task_types": ["embed/text/sparse"]
+                        },
+                        {
+                          "model_name": "rerank-v1",
+                          "task_types": ["rerank/text/text-similarity"]
                         }
                     ]
                 }
@@ -221,16 +225,25 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                                 ".rainbow-sprinkles-elastic",
                                 MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
                                 service
+                            ),
+                            new InferenceService.DefaultConfigId(
+                                ".rerank-v1-elastic",
+                                MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
+                                service
                             )
                         )
                     )
                 );
-                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
+                assertThat(
+                    service.supportedTaskTypes(),
+                    is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
+                );
 
                 PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
                 service.defaultConfigs(listener);
                 assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
                 assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
+                assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
 
                 var getModelListener = new PlainActionFuture<UnparsedModel>();
                 // persists the default endpoints
@@ -248,6 +261,10 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                         {
                           "model_name": "elser-v2",
                           "task_types": ["embed/text/sparse"]
+                        },
+                        {
+                          "model_name": "rerank-v1",
+                          "task_types": ["rerank/text/text-similarity"]
                         }
                     ]
                 }
@@ -267,11 +284,16 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                                 ".elser-v2-elastic",
                                 MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
                                 service
+                            ),
+                            new InferenceService.DefaultConfigId(
+                                ".rerank-v1-elastic",
+                                MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
+                                service
                             )
                         )
                     )
                 );
-                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
+                assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
 
                 var getModelListener = new PlainActionFuture<UnparsedModel>();
                 modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

+ 18 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -52,6 +52,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -95,6 +96,10 @@ public class ElasticInferenceService extends SenderService {
     static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
     static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
 
+    // rerank-v1
+    static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
+    static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
+
     /**
      * The task types that the {@link InferenceAction.Request} can accept.
      */
@@ -159,6 +164,19 @@ public class ElasticInferenceService extends SenderService {
                     elasticInferenceServiceComponents
                 ),
                 MinimalServiceSettings.sparseEmbedding(NAME)
+            ),
+            DEFAULT_RERANK_MODEL_ID_V1,
+            new DefaultModelConfig(
+                new ElasticInferenceServiceRerankModel(
+                    DEFAULT_RERANK_ENDPOINT_ID_V1,
+                    TaskType.RERANK,
+                    NAME,
+                    new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null),
+                    EmptyTaskSettings.INSTANCE,
+                    EmptySecretSettings.INSTANCE,
+                    elasticInferenceServiceComponents
+                ),
+                MinimalServiceSettings.rerank(NAME)
             )
         );
     }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java

@@ -87,7 +87,7 @@ public class ElasticInferenceServiceRerankModel extends ElasticInferenceServiceE
     private URI createUri() throws ElasticsearchStatusException {
         try {
             // TODO, consider transforming the base URL into a URI for better error handling.
-            return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank");
+            return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank/text/text-similarity");
         } catch (URISyntaxException e) {
             throw new ElasticsearchStatusException(
                 "Failed to create URI for service ["

+ 3 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

@@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
         "embed/text/sparse",
         TaskType.SPARSE_EMBEDDING,
         "chat",
-        TaskType.CHAT_COMPLETION
+        TaskType.CHAT_COMPLETION,
+        "rerank/text/text-similarity",
+        TaskType.RERANK
     );
 
     @SuppressWarnings("unchecked")

+ 13 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -1294,6 +1294,10 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     {
                       "model_name": "elser-v2",
                       "task_types": ["embed/text/sparse"]
+                    },
+                    {
+                      "model_name": "rerank-v1",
+                      "task_types": ["rerank/text/text-similarity"]
                     }
                 ]
             }
@@ -1319,18 +1323,25 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                             ".rainbow-sprinkles-elastic",
                             MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
                             service
+                        ),
+                        new InferenceService.DefaultConfigId(
+                            ".rerank-v1-elastic",
+                            MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
+                            service
                         )
                     )
                 )
             );
-            assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
+            assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
 
             PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
             service.defaultConfigs(listener);
             var models = listener.actionGet(TIMEOUT);
-            assertThat(models.size(), is(2));
+            assertThat(models.size(), is(3));
             assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
             assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
+            assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
+
         }
     }