1
0
Эх сурвалжийг харах

Adding inference endpoint creation validation for MistralService, GoogleAiStudioService, and HuggingFaceService (#113492) (#113605)

* Adding inference endpoint creation validation for MistralService, GoogleAiStudioService, and HuggingFaceService

* Moving invalid model type exception to shared ServiceUtils function

* Fixing naming inconsistency

* Updating HuggingFaceIT ELSER tests for inference endpoint validation
Dan Rubinstein 1 жил өмнө
parent
commit
e93b481f41
13 өөрчлөгдсөн 213 нэмэгдсэн , 77 устгасан
  1. 1 0
      x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java
  2. 2 0
      x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java
  3. 7 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
  4. 20 20
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java
  5. 20 22
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
  6. 20 23
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java
  7. 2 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
  8. 16 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
  9. 39 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
  10. 39 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
  11. 44 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
  12. 3 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
  13. 0 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java

+ 1 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java

@@ -84,6 +84,7 @@ public class HuggingFaceServiceMixedIT extends BaseMixedTestCase {
         final String inferenceId = "mixed-cluster-elser";
         final String upgradedClusterId = "upgraded-cluster-elser";
 
+        elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
         put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);
 
         var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints");

+ 2 - 0
x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java

@@ -117,6 +117,7 @@ public class HuggingFaceServiceUpgradeIT extends InferenceUpgradeTestCase {
         var testTaskType = TaskType.SPARSE_EMBEDDING;
 
         if (isOldCluster()) {
+            elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
             put(oldClusterId, elserConfig(getUrl(elserServer)), testTaskType);
             var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier);
             assertThat(configs, hasSize(1));
@@ -136,6 +137,7 @@ public class HuggingFaceServiceUpgradeIT extends InferenceUpgradeTestCase {
             assertElser(oldClusterId);
 
             // New endpoint
+            elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
             put(upgradedClusterId, elserConfig(getUrl(elserServer)), testTaskType);
             configs = (List<Map<String, Object>>) get(upgradedClusterId).get("endpoints");
             assertThat(configs, hasSize(1));

+ 7 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -202,6 +202,13 @@ public final class ServiceUtils {
         );
     }
 
+    public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithEmbeddingDetails(Class<? extends Model> invalidModelType) {
+        throw new ElasticsearchStatusException(
+            Strings.format("Can't update embedding details for model with unexpected type %s", invalidModelType),
+            RestStatus.BAD_REQUEST
+        );
+    }
+
     public static String missingSettingErrorMsg(String settingName, String scope) {
         return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
     }

+ 20 - 20
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
 import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.List;
 import java.util.Map;
@@ -187,30 +188,29 @@ public class GoogleAiStudioService extends SenderService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
-            );
-        } else {
-            listener.onResponse(model);
-        }
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
     }
 
-    private GoogleAiStudioEmbeddingsModel updateModelWithEmbeddingDetails(GoogleAiStudioEmbeddingsModel model, int embeddingSize) {
-        var similarityFromModel = model.getServiceSettings().similarity();
-        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
 
-        GoogleAiStudioEmbeddingsServiceSettings serviceSettings = new GoogleAiStudioEmbeddingsServiceSettings(
-            model.getServiceSettings().modelId(),
-            model.getServiceSettings().maxInputTokens(),
-            embeddingSize,
-            similarityToUse,
-            model.getServiceSettings().rateLimitSettings()
-        );
+            var updatedServiceSettings = new GoogleAiStudioEmbeddingsServiceSettings(
+                serviceSettings.modelId(),
+                serviceSettings.maxInputTokens(),
+                embeddingSize,
+                similarityToUse,
+                serviceSettings.rateLimitSettings()
+            );
 
-        return new GoogleAiStudioEmbeddingsModel(model, serviceSettings);
+            return new GoogleAiStudioEmbeddingsModel(embeddingsModel, updatedServiceSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
     }
 
     @Override

+ 20 - 22
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

@@ -29,6 +29,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.List;
 import java.util.Map;
@@ -67,34 +68,31 @@ public class HuggingFaceService extends HuggingFaceBaseService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof HuggingFaceEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.COSINE : similarityFromModel;
+
+            var updatedServiceSettings = new HuggingFaceServiceSettings(
+                serviceSettings.uri(),
+                similarityToUse,
+                embeddingSize,
+                embeddingsModel.getTokenLimit(),
+                serviceSettings.rateLimitSettings()
             );
+
+            return new HuggingFaceEmbeddingsModel(embeddingsModel, updatedServiceSettings);
         } else {
-            listener.onResponse(model);
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
         }
     }
 
-    private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
-        // default to cosine similarity
-        var similarity = model.getServiceSettings().similarity() == null
-            ? SimilarityMeasure.COSINE
-            : model.getServiceSettings().similarity();
-
-        var serviceSettings = new HuggingFaceServiceSettings(
-            model.getServiceSettings().uri(),
-            similarity,
-            embeddingSize,
-            model.getTokenLimit(),
-            model.getServiceSettings().rateLimitSettings()
-        );
-
-        return new HuggingFaceEmbeddingsModel(model, serviceSettings);
-    }
-
     @Override
     protected void doChunkedInfer(
         Model model,

+ 20 - 23
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.List;
 import java.util.Map;
@@ -214,32 +215,28 @@ public class MistralService extends SenderService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof MistralEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateEmbeddingModelConfig(embeddingsModel, size)))
-            );
-        } else {
-            listener.onResponse(model);
-        }
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
     }
 
-    private MistralEmbeddingsModel updateEmbeddingModelConfig(MistralEmbeddingsModel embeddingsModel, int embeddingsSize) {
-        var embeddingServiceSettings = embeddingsModel.getServiceSettings();
-
-        var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
-        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof MistralEmbeddingsModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
 
-        MistralEmbeddingsServiceSettings serviceSettings = new MistralEmbeddingsServiceSettings(
-            embeddingServiceSettings.modelId(),
-            embeddingsSize,
-            embeddingServiceSettings.maxInputTokens(),
-            similarityToUse,
-            embeddingServiceSettings.rateLimitSettings()
-        );
+            var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
 
-        return new MistralEmbeddingsModel(embeddingsModel, serviceSettings);
+            MistralEmbeddingsServiceSettings updatedServiceSettings = new MistralEmbeddingsServiceSettings(
+                serviceSettings.modelId(),
+                embeddingSize,
+                serviceSettings.maxInputTokens(),
+                similarityToUse,
+                serviceSettings.rateLimitSettings()
+            );
+            return new MistralEmbeddingsModel(embeddingsModel, updatedServiceSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
     }
-
 }

+ 2 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

@@ -12,7 +12,6 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
@@ -35,6 +34,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
@@ -307,10 +307,7 @@ public class OpenAiService extends SenderService {
 
             return new OpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
         } else {
-            throw new ElasticsearchStatusException(
-                Strings.format("Can't update embedding details for model with unexpected type %s", model.getClass()),
-                RestStatus.BAD_REQUEST
-            );
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
         }
     }
 

+ 16 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java

@@ -1,3 +1,4 @@
+
 /*
  * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
  * or more contributor license agreements. Licensed under the Elastic License
@@ -34,14 +35,25 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali
             Map.of(),
             InputType.INGEST,
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener.delegateFailureAndWrap((delegate, r) -> {
+            ActionListener.wrap(r -> {
                 if (r != null) {
-                    delegate.onResponse(r);
+                    listener.onResponse(r);
                 } else {
-                    delegate.onFailure(
-                        new ElasticsearchStatusException("Could not make a validation call to the selected service", RestStatus.BAD_REQUEST)
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            "Could not complete inference endpoint creation as validation call to service returned null response.",
+                            RestStatus.BAD_REQUEST
+                        )
                     );
                 }
+            }, e -> {
+                listener.onFailure(
+                    new ElasticsearchStatusException(
+                        "Could not complete inference endpoint creation as validation call to service threw an exception.",
+                        RestStatus.BAD_REQUEST,
+                        e
+                    )
+                );
             })
         );
     }

+ 39 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

@@ -917,6 +917,45 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = GoogleAiStudioCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10));
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = GoogleAiStudioEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                similarityMeasure
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
         return Map.of(
             ChatCompletionResults.COMPLETION,

+ 39 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

@@ -597,6 +597,45 @@ public class HuggingFaceServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = HuggingFaceElserModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10));
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = HuggingFaceEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                randomNonNegativeInt(),
+                similarityMeasure
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.COSINE : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 

+ 44 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java

@@ -30,6 +30,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -38,6 +39,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
@@ -388,6 +390,48 @@ public class MistralServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = new Model(ModelConfigurationsTests.createRandomInstance());
+
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = MistralEmbeddingModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                randomNonNegativeInt(),
+                similarityMeasure,
+                RateLimitSettingsTests.createRandom()
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws IOException {
         var sender = mock(Sender.class);
 

+ 3 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

@@ -1435,7 +1435,7 @@ public class OpenAiServiceTests extends ESTestCase {
                 randomAlphaOfLength(10),
                 randomAlphaOfLength(10),
                 randomAlphaOfLength(10),
-                null,
+                similarityMeasure,
                 randomNonNegativeInt(),
                 randomNonNegativeInt(),
                 randomBoolean()
@@ -1443,7 +1443,8 @@ public class OpenAiServiceTests extends ESTestCase {
 
             Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
 
-            assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity());
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
             assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
         }
     }

+ 0 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java

@@ -125,7 +125,6 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
             eq(InferenceAction.Request.DEFAULT_TIMEOUT),
             any()
         );
-        verify(mockActionListener).delegateFailureAndWrap(any());
         verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults);
     }
 }