Просмотр исходного кода

Fix ELAND endpoints not updating dimensions (#126537)

* Fix ELAND endpoints not updating dimensions

* Update docs/changelog/126537.yaml
Dan Rubinstein 6 месяцев назад
Родитель
Коммит
44507cce04

+ 5 - 0
docs/changelog/126537.yaml

@@ -0,0 +1,5 @@
+pr: 126537
+summary: Fix ELAND endpoints not updating dimensions
+area: Machine Learning
+type: bug
+issues: []

+ 26 - 18
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -561,25 +561,33 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         }
     }
 
-    private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) {
-        CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
-            model.getServiceSettings().getNumAllocations(),
-            model.getServiceSettings().getNumThreads(),
-            model.getServiceSettings().modelId(),
-            model.getServiceSettings().getAdaptiveAllocationsSettings(),
-            model.getServiceSettings().getDeploymentId(),
-            embeddingSize,
-            model.getServiceSettings().similarity(),
-            model.getServiceSettings().elementType()
-        );
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof CustomElandEmbeddingModel customElandEmbeddingModel && model.getTaskType() == TaskType.TEXT_EMBEDDING) {
+            CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
+                customElandEmbeddingModel.getServiceSettings().getNumAllocations(),
+                customElandEmbeddingModel.getServiceSettings().getNumThreads(),
+                customElandEmbeddingModel.getServiceSettings().modelId(),
+                customElandEmbeddingModel.getServiceSettings().getAdaptiveAllocationsSettings(),
+                customElandEmbeddingModel.getServiceSettings().getDeploymentId(),
+                embeddingSize,
+                customElandEmbeddingModel.getServiceSettings().similarity(),
+                customElandEmbeddingModel.getServiceSettings().elementType()
+            );
+
+            return new CustomElandEmbeddingModel(
+                customElandEmbeddingModel.getInferenceEntityId(),
+                customElandEmbeddingModel.getTaskType(),
+                customElandEmbeddingModel.getConfigurations().getService(),
+                serviceSettings,
+                customElandEmbeddingModel.getConfigurations().getChunkingSettings()
+            );
+        } else if (model instanceof ElasticsearchInternalModel) {
+            return model;
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
 
-        return new CustomElandEmbeddingModel(
-            model.getInferenceEntityId(),
-            model.getTaskType(),
-            model.getConfigurations().getService(),
-            serviceSettings,
-            model.getConfigurations().getChunkingSettings()
-        );
     }
 
     @Override

+ 70 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -69,6 +69,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConf
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.InputTypeTests;
+import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
@@ -886,6 +887,75 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() {
+        var service = createService(mock(Client.class));
+        var model = new Model(ModelConfigurationsTests.createRandomInstance());
+
+        assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); });
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_TextEmbeddingCustomElandEmbeddingsModelUpdatesDimensions() {
+        var service = createService(mock(Client.class));
+        var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
+            1,
+            4,
+            "invalid",
+            null,
+            null,
+            null,
+            SimilarityMeasure.COSINE,
+            DenseVectorFieldMapper.ElementType.FLOAT
+        );
+        var model = new CustomElandEmbeddingModel(
+            randomAlphaOfLength(10),
+            TaskType.TEXT_EMBEDDING,
+            "elasticsearch",
+            elandServiceSettings,
+            null
+        );
+
+        var embeddingSize = randomNonNegativeInt();
+        var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+        assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonTextEmbeddingCustomElandEmbeddingsModelNotModified() {
+        var service = createService(mock(Client.class));
+        var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
+            1,
+            4,
+            "invalid",
+            null,
+            null,
+            null,
+            SimilarityMeasure.COSINE,
+            DenseVectorFieldMapper.ElementType.FLOAT
+        );
+        var model = new CustomElandEmbeddingModel(
+            randomAlphaOfLength(10),
+            TaskType.SPARSE_EMBEDDING,
+            "elasticsearch",
+            elandServiceSettings,
+            null
+        );
+
+        var embeddingSize = randomNonNegativeInt();
+        var updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+        assertEquals(model, updatedModel);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_ElasticsearchInternalModelNotModified() {
+        var service = createService(mock(Client.class));
+        var model = mock(ElasticsearchInternalModel.class);
+
+        var updatedModel = service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt());
+
+        assertEquals(model, updatedModel);
+        verifyNoMoreInteractions(model);
+    }
+
     public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_e5(null);
     }