Browse Source

[ML] Default inference endpoint for the multilingual-e5-small model (#114683) (#114779)

David Kyle 1 year ago
parent
commit
ffcf87ce1d

+ 5 - 0
docs/changelog/114683.yaml

@@ -0,0 +1,5 @@
+pr: 114683
+summary: Default inference endpoint for the multilingual-e5-small model
+area: Machine Learning
+type: enhancement
+issues: []

+ 6 - 1
docs/reference/rest-api/usage.asciidoc

@@ -210,7 +210,12 @@ GET /_xpack/usage
         "service": "elasticsearch",
         "task_type": "SPARSE_EMBEDDING",
         "count": 1
-      }
+      },
+      {
+        "service": "elasticsearch",
+        "task_type": "TEXT_EMBEDDING",
+        "count": 1
+      }
     ]
   },
   "logstash" : {

+ 38 - 3
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java → x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

@@ -22,13 +22,13 @@ import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.oneOf;
 
-public class DefaultElserIT extends InferenceBaseRestTest {
+public class DefaultEndPointsIT extends InferenceBaseRestTest {
 
     private TestThreadPool threadPool;
 
     @Before
     public void createThreadPool() {
-        threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
+        threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName());
     }
 
     @After
@@ -38,7 +38,7 @@ public class DefaultElserIT extends InferenceBaseRestTest {
     }
 
     @SuppressWarnings("unchecked")
-    public void testInferCreatesDefaultElser() throws IOException {
+    public void testInferDeploysDefaultElser() throws IOException {
         assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
         var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
         assertDefaultElserConfig(model);
@@ -67,4 +67,39 @@ public class DefaultElserIT extends InferenceBaseRestTest {
             Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
         );
     }
+
+    @SuppressWarnings("unchecked")
+    public void testInferDeploysDefaultE5() throws IOException {
+        assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
+        var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
+        assertDefaultE5Config(model);
+
+        var inputs = List.of("Hello World", "Goodnight moon");
+        var queryParams = Map.of("timeout", "120s");
+        var results = infer(ElasticsearchInternalService.DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, inputs, queryParams);
+        var embeddings = (List<Map<String, Object>>) results.get("text_embedding");
+        assertThat(results.toString(), embeddings, hasSize(2));
+    }
+
+    @SuppressWarnings("unchecked")
+    private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
+        assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_E5_ID, modelConfig.get("inference_id"));
+        assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
+        assertEquals(modelConfig.toString(), TaskType.TEXT_EMBEDDING.toString(), modelConfig.get("task_type"));
+
+        var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
+        assertThat(
+            modelConfig.toString(),
+            serviceSettings.get("model_id"),
+            is(oneOf(".multilingual-e5-small", ".multilingual-e5-small_linux-x86_64"))
+        );
+        assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
+
+        var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
+        assertThat(
+            modelConfig.toString(),
+            adaptiveAllocations,
+            Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
+        );
+    }
 }

+ 3 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -40,7 +40,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         }
 
         var getAllModels = getAllModels();
-        int numModels = DefaultElserFeatureFlag.isEnabled() ? 10 : 9;
+        int numModels = DefaultElserFeatureFlag.isEnabled() ? 11 : 9;
         assertThat(getAllModels, hasSize(numModels));
 
         var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -51,7 +51,8 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         }
 
         var getDenseModels = getModels("_all", TaskType.TEXT_EMBEDDING);
-        assertThat(getDenseModels, hasSize(4));
+        int numDenseModels = DefaultElserFeatureFlag.isEnabled() ? 5 : 4;
+        assertThat(getDenseModels, hasSize(numDenseModels));
         for (var denseModel : getDenseModels) {
             assertEquals("text_embedding", denseModel.get("task_type"));
         }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

@@ -275,7 +275,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
         return request;
     }
 
-    protected abstract boolean isDefaultId(String inferenceId);
+    abstract boolean isDefaultId(String inferenceId);
 
     protected void maybeStartDeployment(
         ElasticsearchInternalModel model,

+ 20 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -76,6 +76,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
 
     public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
     public static final String DEFAULT_ELSER_ID = ".elser-2";
+    public static final String DEFAULT_E5_ID = ".multi-e5-small";
 
     private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
     private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
@@ -765,7 +766,10 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     }
 
     public List<DefaultConfigId> defaultConfigIds() {
-        return List.of(new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this));
+        return List.of(
+            new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
+            new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
+        );
     }
 
     /**
@@ -805,13 +809,24 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
             ElserMlNodeTaskSettings.DEFAULT,
             null // default chunking settings
         );
-
-        return List.of(defaultElser);
+        var defaultE5 = new MultilingualE5SmallModel(
+            DEFAULT_E5_ID,
+            TaskType.TEXT_EMBEDDING,
+            NAME,
+            new MultilingualE5SmallInternalServiceSettings(
+                null,
+                1,
+                useLinuxOptimizedModel ? MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 : MULTILINGUAL_E5_SMALL_MODEL_ID,
+                new AdaptiveAllocationsSettings(Boolean.TRUE, 1, 8)
+            ),
+            null // default chunking settings
+        );
+        return List.of(defaultElser, defaultE5);
     }
 
     @Override
-    protected boolean isDefaultId(String inferenceId) {
-        return DEFAULT_ELSER_ID.equals(inferenceId);
+    boolean isDefaultId(String inferenceId) {
+        return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
     }
 
     static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(

+ 7 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -1436,6 +1436,13 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]"));
     }
 
+    public void testIsDefaultId() {
+        var service = createService(mock(Client.class));
+        assertTrue(service.isDefaultId(".elser-2"));
+        assertTrue(service.isDefaultId(".multi-e5-small"));
+        assertFalse(service.isDefaultId("foo"));
+    }
+
     private ElasticsearchInternalService createService(Client client) {
         var cs = mock(ClusterService.class);
         var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -234,9 +234,9 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
             if (getModelResponse.getResources().results().size() > 1) {
                 listener.onFailure(
                     ExceptionsHelper.badRequestException(
-                        "cannot deploy more than one models at the same time; [{}] matches [{}] models]",
+                        "cannot deploy more than one model at the same time; [{}] matches models [{}]",
                         request.getModelId(),
-                        getModelResponse.getResources().results().size()
+                        getModelResponse.getResources().results().stream().map(TrainedModelConfig::getModelId).toList()
                     )
                 );
                 return;