Преглед изворни кода

[ML] InferenceService support aliases (#128584) (#128595)

"elser" is an alias for "elasticsearch", and "sagemaker" is an alias for
"amazon_sagemaker".

Users can continue to create and use providers by their alias.
Elasticsearch will continue to support the alias when it reads the
configuration from the internal index.
Pat Whelan пре 4 месеци
родитељ
комит
10d873bc7c

+ 5 - 0
docs/changelog/128584.yaml

@@ -0,0 +1,5 @@
+pr: 128584
+summary: '`InferenceService` support aliases'
+area: Machine Learning
+type: enhancement
+issues: []

+ 8 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -27,6 +27,14 @@ public interface InferenceService extends Closeable {
 
     String name();
 
+    /**
+     * The aliases that map to {@link #name()}. {@link InferenceServiceRegistry} allows users to create and use inference services by one
+     * of their aliases.
+     */
+    default List<String> aliases() {
+        return List.of();
+    }
+
     /**
      * Parse model configuration from the {@code config map} from a request and return
      * the parsed {@link Model}. This requires that both the secrets and service settings be contained in the

+ 8 - 8
server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java

@@ -24,17 +24,22 @@ import java.util.stream.Collectors;
 public class InferenceServiceRegistry implements Closeable {
 
     private final Map<String, InferenceService> services;
+    private final Map<String, String> aliases;
     private final List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
 
     public InferenceServiceRegistry(
         List<InferenceServiceExtension> inferenceServicePlugins,
         InferenceServiceExtension.InferenceServiceFactoryContext factoryContext
     ) {
-        // TODO check names are unique
+        // toMap verifies that the names and aliases are unique
         services = inferenceServicePlugins.stream()
             .flatMap(r -> r.getInferenceServiceFactories().stream())
             .map(factory -> factory.create(factoryContext))
             .collect(Collectors.toMap(InferenceService::name, Function.identity()));
+        aliases = services.values()
+            .stream()
+            .flatMap(service -> service.aliases().stream().distinct().map(alias -> Map.entry(alias, service.name())))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
     }
 
     public void init(Client client) {
@@ -56,13 +61,8 @@ public class InferenceServiceRegistry implements Closeable {
     }
 
     public Optional<InferenceService> getService(String serviceName) {
-
-        if ("elser".equals(serviceName)) { // ElserService.NAME before removal
-            // here we are aliasing the elser service to use the elasticsearch service instead
-            return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME
-        } else {
-            return Optional.ofNullable(services.get(serviceName));
-        }
+        var serviceKey = aliases.getOrDefault(serviceName, serviceName);
+        return Optional.ofNullable(services.get(serviceKey));
     }
 
     public List<NamedWriteableRegistry.Entry> getNamedWriteables() {

+ 1 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

@@ -65,7 +65,7 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
         var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
         assertDefaultRerankConfig(rerankModel);
 
-        putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
+        putModel("my-model", mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
         var registeredModels = getMinimalConfigs();
         assertThat(registeredModels.size(), equalTo(1));
         assertTrue(registeredModels.containsKey("my-model"));

+ 3 - 3
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -119,12 +119,12 @@ public class InferenceBaseRestTest extends ESRestTestCase {
             """, taskType, apiKey, temperature);
     }
 
-    static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody) {
+    static String mockCompletionServiceModelConfig(@Nullable TaskType taskTypeInBody, String service) {
         var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
         return Strings.format("""
             {
               %s
-              "service": "streaming_completion_test_service",
+              "service": "%s",
               "service_settings": {
                 "model": "my_model",
                 "api_key": "abc64"
@@ -133,7 +133,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
                 "temperature": 3
               }
             }
-            """, taskType);
+            """, taskType, service);
     }
 
     static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {

+ 11 - 3
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -305,7 +305,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
 
     public void testUnsupportedStream() throws Exception {
         String modelId = "streaming";
-        putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING));
+        putModel(modelId, mockCompletionServiceModelConfig(TaskType.SPARSE_EMBEDDING, "streaming_completion_test_service"));
         var singleModel = getModel(modelId);
         assertEquals(modelId, singleModel.get("inference_id"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
@@ -326,8 +326,16 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     }
 
     public void testSupportedStream() throws Exception {
+        testSupportedStream("streaming_completion_test_service");
+    }
+
+    public void testSupportedStreamForAlias() throws Exception {
+        testSupportedStream("streaming_completion_test_service_alias");
+    }
+
+    private void testSupportedStream(String serviceName) throws Exception {
         String modelId = "streaming";
-        putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION));
+        putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION, serviceName));
         var singleModel = getModel(modelId);
         assertEquals(modelId, singleModel.get("inference_id"));
         assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type"));
@@ -352,7 +360,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
 
     public void testUnifiedCompletionInference() throws Exception {
         String modelId = "streaming";
-        putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION));
+        putModel(modelId, mockCompletionServiceModelConfig(TaskType.CHAT_COMPLETION, "streaming_completion_test_service"));
         var singleModel = getModel(modelId);
         assertEquals(modelId, singleModel.get("inference_id"));
         assertEquals(TaskType.CHAT_COMPLETION.toString(), singleModel.get("task_type"));

+ 4 - 4
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

@@ -54,7 +54,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "text_embedding_test_service",
                     "voyageai",
                     "watsonxai",
-                    "sagemaker"
+                    "amazon_sagemaker"
                 ).toArray()
             )
         );
@@ -93,7 +93,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "text_embedding_test_service",
                     "voyageai",
                     "watsonxai",
-                    "sagemaker"
+                    "amazon_sagemaker"
                 ).toArray()
             )
         );
@@ -143,7 +143,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "openai",
                     "streaming_completion_test_service",
                     "hugging_face",
-                    "sagemaker"
+                    "amazon_sagemaker"
                 ).toArray()
             )
         );
@@ -158,7 +158,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
         assertThat(
             providers,
             containsInAnyOrder(
-                List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "sagemaker").toArray()
+                List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face", "amazon_sagemaker").toArray()
             )
         );
     }

+ 6 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

@@ -60,6 +60,7 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
 
     public static class TestInferenceService extends AbstractTestInferenceService {
         private static final String NAME = "streaming_completion_test_service";
+        private static final String ALIAS = "streaming_completion_test_service_alias";
         private static final Set<TaskType> supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
 
         private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
@@ -75,6 +76,11 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
             return NAME;
         }
 
+        @Override
+        public List<String> aliases() {
+            return List.of(ALIAS);
+        }
+
         @Override
         protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
             return TestServiceSettings.fromMap(serviceSettingsMap);

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -778,6 +778,11 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         return NAME;
     }
 
+    @Override
+    public List<String> aliases() {
+        return List.of(OLD_ELSER_SERVICE_NAME);
+    }
+
     private RankedDocsResults textSimilarityResultsToRankedDocs(
         List<? extends InferenceResults> results,
         Function<Integer, String> inputSupplier,

+ 9 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java

@@ -45,7 +45,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInva
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;
 
 public class SageMakerService implements InferenceService {
-    public static final String NAME = "sagemaker";
+    public static final String NAME = "amazon_sagemaker";
+    private static final String DISPLAY_NAME = "Amazon SageMaker";
+    private static final List<String> ALIASES = List.of("sagemaker", "amazonsagemaker");
     private static final int DEFAULT_BATCH_SIZE = 256;
     private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
     private final SageMakerModelBuilder modelBuilder;
@@ -67,7 +69,7 @@ public class SageMakerService implements InferenceService {
         this.threadPool = threadPool;
         this.configuration = new LazyInitializable<>(
             () -> new InferenceServiceConfiguration.Builder().setService(NAME)
-                .setName("Amazon SageMaker")
+                .setName(DISPLAY_NAME)
                 .setTaskTypes(supportedTaskTypes())
                 .setConfigurations(configurationMap.get())
                 .build()
@@ -79,6 +81,11 @@ public class SageMakerService implements InferenceService {
         return NAME;
     }
 
+    @Override
+    public List<String> aliases() {
+        return ALIASES;
+    }
+
     @Override
     public void parseRequestConfig(
         String modelId,