1
0
Эх сурвалжийг харах

[ML] Fix check on E5 model platform compatibility (#113437)

Creating an endpoint for the built in multilingual e5 model failed for
linux optimised version due to an error in the logic that checks model
compatibility.
David Kyle 1 жил өмнө
parent
commit
3a04f07c50

+ 6 - 0
docs/changelog/113437.yaml

@@ -0,0 +1,6 @@
+pr: 113437
+summary: Fix check on E5 model platform compatibility
+area: Machine Learning
+type: bug
+issues:
+ - 113577

+ 5 - 6
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

@@ -24,7 +24,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
 
     public void testPutE5Small_withNoModelVariant() {
         {
-            String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
+            String inferenceEntityId = "testPutE5Small_withNoModelVariant";
             expectThrows(
                 org.elasticsearch.client.ResponseException.class,
                 () -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity())
@@ -33,7 +33,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
     }
 
     public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
-        String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
+        String inferenceEntityId = "teste5mall_withplatformagnosticvariant";
         putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity());
         var models = getTrainedModel("_all");
         assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
@@ -50,9 +50,8 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
         deleteTextEmbeddingModel(inferenceEntityId);
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
     public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
-        String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
+        String inferenceEntityId = "teste5mall_withplatformspecificvariant";
         if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
             putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity());
             var models = getTrainedModel("_all");
@@ -77,7 +76,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
     }
 
     public void testPutE5Small_withFakeModelVariant() {
-        String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
+        String inferenceEntityId = "teste5mall_withfakevariant";
         expectThrows(
             org.elasticsearch.client.ResponseException.class,
             () -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity())
@@ -112,7 +111,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
     private String noModelIdVariantJsonEntity() {
         return """
                 {
-                  "service": "text_embedding",
+                  "service": "elasticsearch",
                   "service_settings": {
                     "num_allocations": 1,
                     "num_threads": 1

+ 8 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -201,9 +201,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                     MULTILINGUAL_E5_SMALL_MODEL_ID
                 )
             );
-        }
-
-        if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) {
+        } else if (modelVariantValidForArchitecture(platformArchitectures, esServiceSettingsBuilder.getModelId()) == false) {
             throw new IllegalArgumentException(
                 "Error parsing request config, model id does not match any models available on this platform. Was ["
                     + esServiceSettingsBuilder.getModelId()
@@ -224,17 +222,19 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         );
     }
 
-    private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
-        Set<String> platformArchitectures,
-        String modelId
-    ) {
+    static boolean modelVariantValidForArchitecture(Set<String> platformArchitectures, String modelId) {
+        if (modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID)) {
+            // platform agnostic model is always compatible
+            return true;
+        }
+
         return modelId.equals(
             selectDefaultModelVariantBasedOnClusterArchitecture(
                 platformArchitectures,
                 MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
                 MULTILINGUAL_E5_SMALL_MODEL_ID
             )
-        ) && modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID) == false;
+        );
     }
 
     @Override

+ 32 - 15
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -65,6 +65,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
@@ -167,17 +169,12 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                         ElasticsearchInternalServiceSettings.NUM_THREADS,
                         4,
                         ElasticsearchInternalServiceSettings.MODEL_ID,
-                        ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID
+                        MULTILINGUAL_E5_SMALL_MODEL_ID
                     )
                 )
             );
 
-            var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(
-                1,
-                4,
-                ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
-                null
-            );
+            var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID, null);
 
             service.parseRequestConfig(
                 randomInferenceEntityId,
@@ -201,7 +198,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                         ElasticsearchInternalServiceSettings.NUM_THREADS,
                         4,
                         ElasticsearchInternalServiceSettings.MODEL_ID,
-                        ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
+                        MULTILINGUAL_E5_SMALL_MODEL_ID,
                         "not_a_valid_service_setting",
                         randomAlphaOfLength(10)
                     )
@@ -435,19 +432,14 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                         ElasticsearchInternalServiceSettings.NUM_THREADS,
                         4,
                         ElasticsearchInternalServiceSettings.MODEL_ID,
-                        ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
+                        MULTILINGUAL_E5_SMALL_MODEL_ID,
                         ServiceFields.DIMENSIONS,
                         1
                     )
                 )
             );
 
-            var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(
-                1,
-                4,
-                ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
-                null
-            );
+            var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(1, 4, MULTILINGUAL_E5_SMALL_MODEL_ID, null);
 
             MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig(
                 randomInferenceEntityId,
@@ -950,6 +942,31 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         assertThat(model, is(expectedModel));
     }
 
+    public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
+        {
+            var architectures = Set.of("Aarch64");
+            assertFalse(
+                ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
+            );
+
+            assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
+        }
+        {
+            var architectures = Set.of("linux-x86_64");
+            assertTrue(
+                ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
+            );
+            assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
+        }
+        {
+            var architectures = Set.of("linux-x86_64", "Aarch64");
+            assertFalse(
+                ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
+            );
+            assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
+        }
+    }
+
     private ElasticsearchInternalService createService(Client client) {
         var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
         return new ElasticsearchInternalService(context);