فهرست منبع

[ML] Migrate model_version to model_id when parsing persistent elser inference endpoints (#124769)

* Handling model_version for preexisting endpoints

* Update docs/changelog/124769.yaml
Jonathan Buttner 7 ماه پیش
والد
کامیت
bf53f97efd

+ 7 - 0
docs/changelog/124769.yaml

@@ -0,0 +1,7 @@
+pr: 124769
+summary: Migrate `model_version` to `model_id` when parsing persistent elser inference
+  endpoints
+area: Machine Learning
+type: bug
+issues:
+ - 124675

+ 25 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -18,6 +18,7 @@ import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
@@ -111,6 +112,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
     private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
 
+    /**
+     * Fix for https://github.com/elastic/elasticsearch/issues/124675
+     * In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
+     * service_settings.model_version.
+     */
+    private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";
+
     private final Settings settings;
 
     public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
@@ -489,6 +497,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+        migrateModelVersionToModelId(serviceSettingsMap);
+
         ChunkingSettings chunkingSettings = null;
         if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.SPARSE_EMBEDDING.equals(taskType)) {
             chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
@@ -496,7 +506,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
 
         String modelId = (String) serviceSettingsMap.get(MODEL_ID);
         if (modelId == null) {
-            throw new IllegalArgumentException("Error parsing request config, model id is missing");
+            throw new IllegalArgumentException(
+                Strings.format("Error parsing request config, model id is missing for inference id: %s", inferenceEntityId)
+            );
         }
 
         if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
@@ -536,6 +548,18 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         }
     }
 
+    /**
+     * Fix for https://github.com/elastic/elasticsearch/issues/124675: elser inference endpoints created before 8.13.0 persisted
+     * their model under service_settings.model_version instead of model_id. If the old key is present in the given map, its value
+     * is stored in place under ElserInternalServiceSettings.MODEL_ID, replacing any value already there; otherwise nothing changes.
+     */
+    private void migrateModelVersionToModelId(Map<String, Object> serviceSettingsMap) {
+        if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) {
+            String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class);
+            serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId);
+        }
+    }
+
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
         if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {

+ 27 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -100,6 +100,7 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.Elasticse
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
@@ -709,6 +710,30 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 
     public void testParsePersistedConfig() {
 
+        // Parsing a persistent configuration using model_version succeeds
+        {
+            var service = createService(mock(Client.class));
+            var settings = new HashMap<String, Object>();
+            settings.put(
+                ModelConfigurations.SERVICE_SETTINGS,
+                new HashMap<>(
+                    Map.of(
+                        ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS,
+                        1,
+                        ElasticsearchInternalServiceSettings.NUM_THREADS,
+                        4,
+                        "model_version",
+                        ".elser_model_2"
+                    )
+                )
+            );
+
+            var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings);
+            assertThat(model, instanceOf(ElserInternalModel.class));
+            ElserInternalModel elserInternalModel = (ElserInternalModel) model;
+            assertThat(elserInternalModel.getServiceSettings().modelId(), is(".elser_model_2"));
+        }
+
         // Null model variant
         {
             var service = createService(mock(Client.class));
@@ -727,11 +752,12 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                 )
             );
 
-            expectThrows(
+            var exception = expectThrows(
                 IllegalArgumentException.class,
                 () -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings)
             );
 
+            assertThat(exception.getMessage(), containsString(randomInferenceEntityId));
         }
 
         // Invalid model variant