Bladeren bron

Adding endpoint creation validation for all task types to remaining services (#115020) (#117490)

* Adding endpoint creation validation for all task types to remaining services

* Update Cohere IT tests for rerank validation

* Adding missing import

* Update docs/changelog/115020.yaml

* Fixing GoogleVertex tests after merge from upstream

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dan Rubinstein 10 maanden geleden
bovenliggende
commit
f5295abbea
17 gewijzigde bestanden met toevoegingen van 473 en 218 verwijderingen
  1. 5 0
      docs/changelog/115020.yaml
  2. 1 0
      x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java
  3. 2 0
      x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java
  4. 23 53
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
  5. 25 39
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java
  6. 7 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
  7. 23 39
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
  8. 24 24
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
  9. 21 36
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
  10. 23 23
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
  11. 40 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
  12. 74 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
  13. 47 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
  14. 44 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
  15. 34 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java
  16. 29 4
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java
  17. 51 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

+ 5 - 0
docs/changelog/115020.yaml

@@ -0,0 +1,5 @@
+pr: 115020
+summary: Adding endpoint creation validation for all task types to remaining services
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java

@@ -135,6 +135,7 @@ public class CohereServiceMixedIT extends BaseMixedTestCase {
 
         final String inferenceId = "mixed-cluster-rerank";
 
+        cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
         put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
         assertRerank(inferenceId);
 

+ 2 - 0
x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

@@ -201,6 +201,7 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
         var testTaskType = TaskType.RERANK;
 
         if (isOldCluster()) {
+            cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
             put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType);
             var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier);
             assertThat(configs, hasSize(1));
@@ -229,6 +230,7 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
             assertRerank(oldClusterId);
 
             // New endpoint
+            cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
             put(upgradedClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType);
             configs = (List<Map<String, Object>>) get(upgradedClusterId).get("endpoints");
             assertThat(configs, hasSize(1));

+ 23 - 53
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

@@ -18,7 +18,6 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySettingsConfiguration;
-import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
@@ -51,6 +50,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.Alib
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -60,7 +60,6 @@ import java.util.stream.Stream;
 
 import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
 import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
-import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
@@ -332,68 +331,39 @@ public class AlibabaCloudSearchService extends SenderService {
      */
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
+            var serviceSettings = embeddingsModel.getServiceSettings();
+
+            var updatedServiceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings(
+                new AlibabaCloudSearchServiceSettings(
+                    serviceSettings.getCommonSettings().modelId(),
+                    serviceSettings.getCommonSettings().getHost(),
+                    serviceSettings.getCommonSettings().getWorkspaceName(),
+                    serviceSettings.getCommonSettings().getHttpSchema(),
+                    serviceSettings.getCommonSettings().rateLimitSettings()
+                ),
+                SimilarityMeasure.DOT_PRODUCT,
+                embeddingSize,
+                serviceSettings.getMaxInputTokens()
             );
+
+            return new AlibabaCloudSearchEmbeddingsModel(embeddingsModel, updatedServiceSettings);
         } else {
-            checkAlibabaCloudSearchServiceConfig(model, this, listener);
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
         }
     }
 
-    private AlibabaCloudSearchEmbeddingsModel updateModelWithEmbeddingDetails(AlibabaCloudSearchEmbeddingsModel model, int embeddingSize) {
-        AlibabaCloudSearchEmbeddingsServiceSettings serviceSettings = new AlibabaCloudSearchEmbeddingsServiceSettings(
-            new AlibabaCloudSearchServiceSettings(
-                model.getServiceSettings().getCommonSettings().modelId(),
-                model.getServiceSettings().getCommonSettings().getHost(),
-                model.getServiceSettings().getCommonSettings().getWorkspaceName(),
-                model.getServiceSettings().getCommonSettings().getHttpSchema(),
-                model.getServiceSettings().getCommonSettings().rateLimitSettings()
-            ),
-            SimilarityMeasure.DOT_PRODUCT,
-            embeddingSize,
-            model.getServiceSettings().getMaxInputTokens()
-        );
-
-        return new AlibabaCloudSearchEmbeddingsModel(model, serviceSettings);
-    }
-
     @Override
     public TransportVersion getMinimalSupportedVersion() {
         return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED;
     }
 
-    /**
-     * For other models except of text embedding
-     * check the model's service settings and task settings
-     *
-     * @param model The new model
-     * @param service The inferenceService
-     * @param listener The listener
-     */
-    private void checkAlibabaCloudSearchServiceConfig(Model model, InferenceService service, ActionListener<Model> listener) {
-        String input = ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT;
-        String query = model.getTaskType().equals(TaskType.RERANK) ? ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY : null;
-
-        service.infer(
-            model,
-            query,
-            List.of(input),
-            false,
-            Map.of(),
-            InputType.INGEST,
-            DEFAULT_TIMEOUT,
-            listener.delegateFailureAndWrap((delegate, r) -> {
-                listener.onResponse(model);
-            })
-        );
-    }
-
-    private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_INPUT = "input";
-    private static final String ALIBABA_CLOUD_SEARCH_SERVICE_CONFIG_QUERY = "query";
-
     public static class Configuration {
         public static InferenceServiceConfiguration get() {
             return configuration.getOrCompute();

+ 25 - 39
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo
 import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.io.IOException;
 import java.util.EnumSet;
@@ -303,49 +304,34 @@ public class AmazonBedrockService extends SenderService {
      */
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
-            );
-        } else {
-            listener.onResponse(model);
-        }
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
     }
 
-    private AmazonBedrockEmbeddingsModel updateModelWithEmbeddingDetails(AmazonBedrockEmbeddingsModel model, int embeddingSize) {
-        AmazonBedrockEmbeddingsServiceSettings serviceSettings = model.getServiceSettings();
-        if (serviceSettings.dimensionsSetByUser()
-            && serviceSettings.dimensions() != null
-            && serviceSettings.dimensions() != embeddingSize) {
-            throw new ElasticsearchStatusException(
-                Strings.format(
-                    "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
-                        + "Please recreate the [%s] configuration with the correct dimensions",
-                    embeddingSize,
-                    serviceSettings.dimensions(),
-                    model.getConfigurations().getInferenceEntityId()
-                ),
-                RestStatus.BAD_REQUEST
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null
+                ? getProviderDefaultSimilarityMeasure(embeddingsModel.provider())
+                : similarityFromModel;
+
+            var updatedServiceSettings = new AmazonBedrockEmbeddingsServiceSettings(
+                serviceSettings.region(),
+                serviceSettings.modelId(),
+                serviceSettings.provider(),
+                embeddingSize,
+                serviceSettings.dimensionsSetByUser(),
+                serviceSettings.maxInputTokens(),
+                similarityToUse,
+                serviceSettings.rateLimitSettings()
             );
-        }
-
-        var similarityFromModel = serviceSettings.similarity();
-        var similarityToUse = similarityFromModel == null ? getProviderDefaultSimilarityMeasure(model.provider()) : similarityFromModel;
-
-        AmazonBedrockEmbeddingsServiceSettings settingsToUse = new AmazonBedrockEmbeddingsServiceSettings(
-            serviceSettings.region(),
-            serviceSettings.modelId(),
-            serviceSettings.provider(),
-            embeddingSize,
-            serviceSettings.dimensionsSetByUser(),
-            serviceSettings.maxInputTokens(),
-            similarityToUse,
-            serviceSettings.rateLimitSettings()
-        );
 
-        return new AmazonBedrockEmbeddingsModel(model, settingsToUse);
+            return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedServiceSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
     }
 
     private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvider provider) {

+ 7 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

@@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -176,6 +177,12 @@ public class AnthropicService extends SenderService {
         );
     }
 
+    @Override
+    public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
+    }
+
     @Override
     public InferenceServiceConfiguration getConfiguration() {
         return Configuration.get();

+ 23 - 39
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

@@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
@@ -46,6 +45,7 @@ import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOp
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -294,48 +294,32 @@ public class AzureOpenAiService extends SenderService {
      */
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
-            );
-        } else {
-            listener.onResponse(model);
-        }
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
     }
 
-    private AzureOpenAiEmbeddingsModel updateModelWithEmbeddingDetails(AzureOpenAiEmbeddingsModel model, int embeddingSize) {
-        if (model.getServiceSettings().dimensionsSetByUser()
-            && model.getServiceSettings().dimensions() != null
-            && model.getServiceSettings().dimensions() != embeddingSize) {
-            throw new ElasticsearchStatusException(
-                Strings.format(
-                    "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
-                        + "Please recreate the [%s] configuration with the correct dimensions",
-                    embeddingSize,
-                    model.getServiceSettings().dimensions(),
-                    model.getConfigurations().getInferenceEntityId()
-                ),
-                RestStatus.BAD_REQUEST
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
+            var updatedServiceSettings = new AzureOpenAiEmbeddingsServiceSettings(
+                serviceSettings.resourceName(),
+                serviceSettings.deploymentId(),
+                serviceSettings.apiVersion(),
+                embeddingSize,
+                serviceSettings.dimensionsSetByUser(),
+                serviceSettings.maxInputTokens(),
+                similarityToUse,
+                serviceSettings.rateLimitSettings()
             );
-        }
-
-        var similarityFromModel = model.getServiceSettings().similarity();
-        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
-
-        AzureOpenAiEmbeddingsServiceSettings serviceSettings = new AzureOpenAiEmbeddingsServiceSettings(
-            model.getServiceSettings().resourceName(),
-            model.getServiceSettings().deploymentId(),
-            model.getServiceSettings().apiVersion(),
-            embeddingSize,
-            model.getServiceSettings().dimensionsSetByUser(),
-            model.getServiceSettings().maxInputTokens(),
-            similarityToUse,
-            model.getServiceSettings().rateLimitSettings()
-        );
 
-        return new AzureOpenAiEmbeddingsModel(model, serviceSettings);
+            return new AzureOpenAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
     }
 
     @Override

+ 24 - 24
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

@@ -45,6 +45,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbedd
 import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -293,36 +294,35 @@ public class CohereService extends SenderService {
      */
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof CohereEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
+
+            var updatedServiceSettings = new CohereEmbeddingsServiceSettings(
+                new CohereServiceSettings(
+                    serviceSettings.getCommonSettings().uri(),
+                    similarityToUse,
+                    embeddingSize,
+                    serviceSettings.getCommonSettings().maxInputTokens(),
+                    serviceSettings.getCommonSettings().modelId(),
+                    serviceSettings.getCommonSettings().rateLimitSettings()
+                ),
+                serviceSettings.getEmbeddingType()
             );
+
+            return new CohereEmbeddingsModel(embeddingsModel, updatedServiceSettings);
         } else {
-            listener.onResponse(model);
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
         }
     }
 
-    private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
-        var userDefinedSimilarity = model.getServiceSettings().similarity();
-        var similarityToUse = userDefinedSimilarity == null ? defaultSimilarity() : userDefinedSimilarity;
-
-        CohereEmbeddingsServiceSettings serviceSettings = new CohereEmbeddingsServiceSettings(
-            new CohereServiceSettings(
-                model.getServiceSettings().getCommonSettings().uri(),
-                similarityToUse,
-                embeddingSize,
-                model.getServiceSettings().getCommonSettings().maxInputTokens(),
-                model.getServiceSettings().getCommonSettings().modelId(),
-                model.getServiceSettings().getCommonSettings().rateLimitSettings()
-            ),
-            model.getServiceSettings().getEmbeddingType()
-        );
-
-        return new CohereEmbeddingsModel(model, serviceSettings);
-    }
-
     /**
      * Return the default similarity measure for the embedding type.
      * Cohere embeddings are normalized to unit vectors therefor Dot

+ 21 - 36
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

@@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
@@ -45,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.Goog
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -181,15 +181,8 @@ public class GoogleVertexAiService extends SenderService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof GoogleVertexAiEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
-            );
-        } else {
-            listener.onResponse(model);
-        }
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
     }
 
     @Override
@@ -240,34 +233,26 @@ public class GoogleVertexAiService extends SenderService {
         }
     }
 
-    private GoogleVertexAiEmbeddingsModel updateModelWithEmbeddingDetails(GoogleVertexAiEmbeddingsModel model, int embeddingSize) {
-        if (model.getServiceSettings().dimensionsSetByUser()
-            && model.getServiceSettings().dimensions() != null
-            && model.getServiceSettings().dimensions() != embeddingSize) {
-            throw new ElasticsearchStatusException(
-                Strings.format(
-                    "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
-                        + "Please recreate the [%s] configuration with the correct dimensions",
-                    embeddingSize,
-                    model.getServiceSettings().dimensions(),
-                    model.getConfigurations().getInferenceEntityId()
-                ),
-                RestStatus.BAD_REQUEST
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof GoogleVertexAiEmbeddingsModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
+
+            var updatedServiceSettings = new GoogleVertexAiEmbeddingsServiceSettings(
+                serviceSettings.location(),
+                serviceSettings.projectId(),
+                serviceSettings.modelId(),
+                serviceSettings.dimensionsSetByUser(),
+                serviceSettings.maxInputTokens(),
+                embeddingSize,
+                serviceSettings.similarity(),
+                serviceSettings.rateLimitSettings()
             );
-        }
-
-        GoogleVertexAiEmbeddingsServiceSettings serviceSettings = new GoogleVertexAiEmbeddingsServiceSettings(
-            model.getServiceSettings().location(),
-            model.getServiceSettings().projectId(),
-            model.getServiceSettings().modelId(),
-            model.getServiceSettings().dimensionsSetByUser(),
-            model.getServiceSettings().maxInputTokens(),
-            embeddingSize,
-            model.getServiceSettings().similarity(),
-            model.getServiceSettings().rateLimitSettings()
-        );
 
-        return new GoogleVertexAiEmbeddingsModel(model, serviceSettings);
+            return new GoogleVertexAiEmbeddingsModel(embeddingsModel, updatedServiceSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+        }
     }
 
     private static GoogleVertexAiModel createModelFromPersistent(

+ 23 - 23
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java

@@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -228,35 +229,34 @@ public class IbmWatsonxService extends SenderService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
+            var updatedServiceSettings = new IbmWatsonxEmbeddingsServiceSettings(
+                serviceSettings.modelId(),
+                serviceSettings.projectId(),
+                serviceSettings.url(),
+                serviceSettings.apiVersion(),
+                serviceSettings.maxInputTokens(),
+                embeddingSize,
+                similarityToUse,
+                serviceSettings.rateLimitSettings()
             );
+
+            return new IbmWatsonxEmbeddingsModel(embeddingsModel, updatedServiceSettings);
         } else {
-            listener.onResponse(model);
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
         }
     }
 
-    private IbmWatsonxEmbeddingsModel updateModelWithEmbeddingDetails(IbmWatsonxEmbeddingsModel model, int embeddingSize) {
-        var similarityFromModel = model.getServiceSettings().similarity();
-        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
-
-        IbmWatsonxEmbeddingsServiceSettings serviceSettings = new IbmWatsonxEmbeddingsServiceSettings(
-            model.getServiceSettings().modelId(),
-            model.getServiceSettings().projectId(),
-            model.getServiceSettings().url(),
-            model.getServiceSettings().apiVersion(),
-            model.getServiceSettings().maxInputTokens(),
-            embeddingSize,
-            similarityToUse,
-            model.getServiceSettings().rateLimitSettings()
-        );
-
-        return new IbmWatsonxEmbeddingsModel(model, serviceSettings);
-    }
-
     @Override
     protected void doInfer(
         Model model,

+ 40 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.bytes.BytesArray;
@@ -22,6 +23,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -50,6 +52,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
 import org.hamcrest.MatcherAssert;
 import org.junit.After;
 import org.junit.Before;
@@ -325,6 +328,43 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = OpenAiChatCompletionModelTests.createChatCompletionModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10)
+            );
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_UpdatesEmbeddingSizeAndSimilarity() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = AlibabaCloudSearchEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomFrom(TaskType.values()),
+                AlibabaCloudSearchEmbeddingsServiceSettingsTests.createRandom(),
+                AlibabaCloudSearchEmbeddingsTaskSettingsTests.createRandom(),
+                null
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            assertEquals(SimilarityMeasure.DOT_PRODUCT, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException {
         testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings());
     }

+ 74 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

@@ -50,6 +50,7 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo
 import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
@@ -72,6 +73,7 @@ import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.c
 import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettingsTests.getAmazonBedrockSecretSettingsMap;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettingsTests.createChatCompletionRequestSettingsMap;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap;
@@ -1375,6 +1377,78 @@ public class AmazonBedrockServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var sender = mock(Sender.class);
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender()).thenReturn(sender);
+
+        var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory(
+            ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
+            mockClusterServiceEmpty()
+        );
+
+        try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
+            var model = AmazonBedrockChatCompletionModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomFrom(AmazonBedrockProvider.values()),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10)
+            );
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var sender = mock(Sender.class);
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender()).thenReturn(sender);
+
+        var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory(
+            ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
+            mockClusterServiceEmpty()
+        );
+
+        try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var provider = randomFrom(AmazonBedrockProvider.values());
+            var model = AmazonBedrockEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                provider,
+                randomNonNegativeInt(),
+                randomBoolean(),
+                randomNonNegativeInt(),
+                similarityMeasure,
+                RateLimitSettingsTests.createRandom(),
+                createRandomChunkingSettings(),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10)
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null
+                ? getProviderDefaultSimilarityMeasure(provider)
+                : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testInfer_UnauthorizedResponse() throws IOException {
         var sender = mock(Sender.class);
         var factory = mock(HttpRequestSender.Factory.class);

+ 47 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

@@ -1194,6 +1194,53 @@ public class AzureOpenAiServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = AzureOpenAiCompletionModelTests.createModelWithRandomValues();
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = AzureOpenAiEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                randomBoolean(),
+                randomNonNegativeInt(),
+                similarityMeasure,
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10)
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 

+ 44 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

@@ -1074,6 +1074,50 @@ public class CohereServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = CohereCompletionModelTests.createModel(randomAlphaOfLength(10), randomAlphaOfLength(10), randomAlphaOfLength(10));
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = CohereEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                randomNonNegativeInt(),
+                randomNonNegativeInt(),
+                randomAlphaOfLength(10),
+                randomFrom(CohereEmbeddingType.values()),
+                similarityMeasure
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? CohereService.defaultSimilarity() : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testInfer_UnauthorisedResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 

+ 34 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java

@@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockWebServer;
@@ -30,9 +31,11 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
+import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModelTests;
 import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.Matchers;
@@ -827,6 +830,37 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        try (var service = createGoogleVertexAiService()) {
+            var model = GoogleVertexAiRerankModelTests.createModel(randomAlphaOfLength(10), randomNonNegativeInt());
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        try (var service = createGoogleVertexAiService()) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = GoogleVertexAiEmbeddingsModelTests.createModel(randomAlphaOfLength(10), randomBoolean(), similarityMeasure);
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     // testInfer tested via end-to-end notebook tests in AppEx repo
 
     @SuppressWarnings("checkstyle:LineLength")

+ 29 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java

@@ -64,7 +64,7 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
     }
 
     public void testOverrideWith_SetsInputTypeToOverride_WhenFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() {
-        var model = createModel("model", Boolean.FALSE, null);
+        var model = createModel("model", Boolean.FALSE, (InputType) null);
         var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.SEARCH);
 
         var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
@@ -80,7 +80,7 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
     }
 
     public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() {
-        var model = createModel("model", Boolean.FALSE, null);
+        var model = createModel("model", Boolean.FALSE, (InputType) null);
         var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, InputType.CLUSTERING), InputType.SEARCH);
 
         var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
@@ -96,10 +96,10 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
     }
 
     public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() {
-        var model = createModel("model", Boolean.FALSE, null);
+        var model = createModel("model", Boolean.FALSE, (InputType) null);
         var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.UNSPECIFIED);
 
-        var expectedModel = createModel("model", Boolean.FALSE, null);
+        var expectedModel = createModel("model", Boolean.FALSE, (InputType) null);
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
@@ -136,6 +136,31 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
         );
     }
 
+    public static GoogleVertexAiEmbeddingsModel createModel(
+        String modelId,
+        @Nullable Boolean autoTruncate,
+        SimilarityMeasure similarityMeasure
+    ) {
+        return new GoogleVertexAiEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new GoogleVertexAiEmbeddingsServiceSettings(
+                randomAlphaOfLength(8),
+                randomAlphaOfLength(8),
+                modelId,
+                false,
+                null,
+                null,
+                similarityMeasure,
+                null
+            ),
+            new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, randomFrom(InputType.INGEST, InputType.SEARCH)),
+            null,
+            new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray()))
+        );
+    }
+
     public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullable Boolean autoTruncate, @Nullable InputType inputType) {
         return new GoogleVertexAiEmbeddingsModel(
             "id",

+ 51 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

@@ -51,6 +51,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests;
 import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
 import org.junit.After;
@@ -930,6 +931,56 @@ public class IbmWatsonxServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = OpenAiChatCompletionModelTests.createChatCompletionModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10)
+            );
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = IbmWatsonxEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                URI.create(randomAlphaOfLength(10)),
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                similarityMeasure
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
     public void testGetConfiguration() throws Exception {
         try (var service = createIbmWatsonxService()) {
             String content = XContentHelper.stripWhitespace("""