Browse Source

[ML] Adding elser default endpoint for EIS (#122066)

* Adding elser default endpoint

* [CI] Auto commit changes from spotless

* Fixing test and allowing duplicate calls

* [CI] Auto commit changes from spotless

* Update docs/changelog/122066.yaml

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Jonathan Buttner 8 months ago
parent
commit
b9d122205a

+ 5 - 0
docs/changelog/122066.yaml

@@ -0,0 +1,5 @@
+pr: 122066
+summary: Adding elser default endpoint for EIS
+area: Machine Learning
+type: enhancement
+issues: []

+ 12 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

@@ -12,10 +12,13 @@ package org.elasticsearch.xpack.inference;
 import org.elasticsearch.inference.TaskType;
 
 import java.io.IOException;
+import java.util.List;
+import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels;
 import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
 
 public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest {
 
@@ -23,12 +26,20 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
         var allModels = getAllModels();
         var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
 
-        assertThat(allModels, hasSize(4));
+        assertThat(allModels, hasSize(5));
         assertThat(chatCompletionModels, hasSize(1));
 
         for (var model : chatCompletionModels) {
             assertEquals("chat_completion", model.get("task_type"));
         }
 
+        assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
+        assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
+    }
+
+    private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {
+        var model = models.stream().filter(m -> m.get("inference_id").equals(inferenceId)).findFirst();
+        assertTrue("could not find inference id: " + inferenceId, model.isPresent());
+        assertThat(model.get().get("task_type"), is(taskType.toString()));
     }
 }

+ 11 - 3
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

@@ -204,6 +204,7 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                     service.defaultConfigIds(),
                     is(
                         List.of(
+                            new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
                             new InferenceService.DefaultConfigId(
                                 ".rainbow-sprinkles-elastic",
                                 MinimalServiceSettings.chatCompletion(),
@@ -216,7 +217,8 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
 
                 PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
                 service.defaultConfigs(listener);
-                assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
+                assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
+                assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
 
                 var getModelListener = new PlainActionFuture<UnparsedModel>();
                 // persists the default endpoints
@@ -244,12 +246,18 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
             try (var service = createElasticInferenceService()) {
                 service.waitForAuthorizationToComplete(TIMEOUT);
                 assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
-                assertTrue(service.defaultConfigIds().isEmpty());
+                assertThat(
+                    service.defaultConfigIds(),
+                    is(
+                        List.of(
+                            new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service)
+                        )
+                    )
+                );
                 assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
 
                 var getModelListener = new PlainActionFuture<UnparsedModel>();
                 modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
-
                 var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT));
                 assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]"));
             }

+ 11 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

@@ -126,11 +126,20 @@ public class ModelRegistry {
         return defaultConfigIds.containsKey(inferenceEntityId);
     }
 
+    /**
+     * Adds the default configuration information if it does not already exist internally.
+     * @param defaultConfigId the default endpoint information
+     */
+    public synchronized void putDefaultIdIfAbsent(InferenceService.DefaultConfigId defaultConfigId) {
+        defaultConfigIds.putIfAbsent(defaultConfigId.inferenceId(), defaultConfigId);
+    }
+
     /**
      * Set the default inference ids provided by the services
-     * @param defaultConfigId The default
+     * @param defaultConfigId The default endpoint information
+     * @throws IllegalStateException if the {@link InferenceService.DefaultConfigId#inferenceId()} already exists internally
      */
-    public synchronized void addDefaultIds(InferenceService.DefaultConfigId defaultConfigId) {
+    public synchronized void addDefaultIds(InferenceService.DefaultConfigId defaultConfigId) throws IllegalStateException {
         var config = defaultConfigIds.get(defaultConfigId.inferenceId());
         if (config != null) {
             throw new IllegalStateException(

+ 30 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -57,6 +57,7 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 
 import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -65,6 +66,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.TreeSet;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
@@ -90,14 +92,24 @@ public class ElasticInferenceService extends SenderService {
     private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
     private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
     private static final String SERVICE_NAME = "Elastic";
+
+    // rainbow-sprinkles
     static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
-    static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
+    static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
+
+    // elser-v2
+    static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
+    static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
 
     /**
      * The task types that the {@link InferenceAction.Request} can accept.
      */
     private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
 
+    private static String defaultEndpointId(String modelId) {
+        return Strings.format(".%s-elastic", modelId);
+    }
+
     private final ElasticInferenceServiceComponents elasticInferenceServiceComponents;
     private Configuration configuration;
     private final AtomicReference<AuthorizedContent> authRef = new AtomicReference<>(AuthorizedContent.empty());
@@ -142,6 +154,19 @@ public class ElasticInferenceService extends SenderService {
                     elasticInferenceServiceComponents
                 ),
                 MinimalServiceSettings.chatCompletion()
+            ),
+            DEFAULT_ELSER_MODEL_ID_V2,
+            new DefaultModelConfig(
+                new ElasticInferenceServiceSparseEmbeddingsModel(
+                    DEFAULT_ELSER_ENDPOINT_ID_V2,
+                    TaskType.SPARSE_EMBEDDING,
+                    NAME,
+                    new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
+                    EmptyTaskSettings.INSTANCE,
+                    EmptySecretSettings.INSTANCE,
+                    elasticInferenceServiceComponents
+                ),
+                MinimalServiceSettings.sparseEmbedding()
             )
         );
     }
@@ -184,13 +209,13 @@ public class ElasticInferenceService extends SenderService {
 
         configuration = new Configuration(authRef.get().taskTypesAndModels.getAuthorizedTaskTypes());
 
-        defaultConfigIds().forEach(modelRegistry::addDefaultIds);
+        defaultConfigIds().forEach(modelRegistry::putDefaultIdIfAbsent);
         handleRevokedDefaultConfigs(authorizedDefaultModelIds);
     }
 
     private Set<String> getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
         var authorizedModels = auth.getAuthorizedModelIds();
-        var authorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet());
+        var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet());
         authorizedDefaultModelIds.retainAll(authorizedModels);
 
         return authorizedDefaultModelIds;
@@ -218,6 +243,7 @@ public class ElasticInferenceService extends SenderService {
             }
         }
 
+        authorizedConfigIds.sort(Comparator.comparing(DefaultConfigId::inferenceId));
         return authorizedConfigIds;
     }
 
@@ -230,6 +256,7 @@ public class ElasticInferenceService extends SenderService {
             }
         }
 
+        authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model.getInferenceEntityId()));
         return authorizedModels;
     }
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java

@@ -75,7 +75,7 @@ public class ElasticInferenceServiceSparseEmbeddingsServiceSettings extends Filt
     public ElasticInferenceServiceSparseEmbeddingsServiceSettings(
         String modelId,
         @Nullable Integer maxInputTokens,
-        RateLimitSettings rateLimitSettings
+        @Nullable RateLimitSettings rateLimitSettings
     ) {
         this.modelId = Objects.requireNonNull(modelId);
         this.maxInputTokens = maxInputTokens;

+ 11 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -934,13 +934,17 @@ public class ElasticInferenceServiceTests extends ESTestCase {
         }
     }
 
-    public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception {
+    public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception {
         String responseJson = """
             {
                 "models": [
                     {
                       "model_name": "rainbow-sprinkles",
                       "task_types": ["chat"]
+                    },
+                    {
+                      "model_name": "elser-v2",
+                      "task_types": ["embed/text/sparse"]
                     }
                 ]
             }
@@ -957,15 +961,19 @@ public class ElasticInferenceServiceTests extends ESTestCase {
                 service.defaultConfigIds(),
                 is(
                     List.of(
+                        new InferenceService.DefaultConfigId(".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(), service),
                         new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service)
                     )
                 )
             );
-            assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
+            assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
 
             PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
             service.defaultConfigs(listener);
-            assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
+            var models = listener.actionGet(TIMEOUT);
+            assertThat(models.size(), is(2));
+            assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
+            assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
         }
     }