Просмотр исходного кода

[ML] Allow users to specify similarity field (#106493)

* Allow users to specify similarity

* Adding l2_norm and e5 fields

* Bumping minimum versions for services

* Cleaning up

* Fixing merge issue

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Jonathan Buttner 1 год назад
Родитель
Сommit
8f28a7a47a
25 измененных файлов с 504 добавлено и 85 удалено
  1. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  2. 25 1
      server/src/main/java/org/elasticsearch/inference/SimilarityMeasure.java
  3. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java
  4. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsExecutableRequestCreator.java
  5. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java
  6. 8 11
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
  7. 8 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
  8. 10 8
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java
  9. 11 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java
  10. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
  11. 28 12
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java
  12. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
  13. 3 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java
  14. 5 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
  15. 4 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java
  16. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
  17. 5 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java
  18. 154 24
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
  19. 23 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java
  20. 13 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
  21. 7 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java
  22. 54 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
  23. 18 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java
  24. 113 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
  25. 7 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -153,6 +153,7 @@ public class TransportVersions {
     public static final TransportVersion USE_DATA_STREAM_GLOBAL_RETENTION = def(8_613_00_0);
     public static final TransportVersion ML_COMPLETION_INFERENCE_SERVICE_ADDED = def(8_614_00_0);
     public static final TransportVersion ML_INFERENCE_EMBEDDING_BYTE_ADDED = def(8_615_00_0);
+    public static final TransportVersion ML_INFERENCE_L2_NORM_SIMILARITY_ADDED = def(8_616_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 25 - 1
server/src/main/java/org/elasticsearch/inference/SimilarityMeasure.java

@@ -8,11 +8,18 @@
 
 package org.elasticsearch.inference;
 
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+
+import java.util.EnumSet;
 import java.util.Locale;
 
 public enum SimilarityMeasure {
     COSINE,
-    DOT_PRODUCT;
+    DOT_PRODUCT,
+    L2_NORM;
+
+    private static final EnumSet<SimilarityMeasure> BEFORE_L2_NORM_ENUMS = EnumSet.range(COSINE, DOT_PRODUCT);
 
     @Override
     public String toString() {
@@ -22,4 +29,21 @@ public enum SimilarityMeasure {
     public static SimilarityMeasure fromString(String name) {
         return valueOf(name.trim().toUpperCase(Locale.ROOT));
     }
+
+    /**
+     * Returns a similarity measure that is known based on the transport version provided. If the similarity enum was not yet
+     * introduced it will be defaulted to null.
+     *
+     * @param similarityMeasure the value to translate if necessary
+     * @param version the version that dictates the translation
+     * @return the similarity that is known to the version passed in
+     */
+    public static SimilarityMeasure translateSimilarity(SimilarityMeasure similarityMeasure, TransportVersion version) {
+        if (version.before(TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED)
+            && BEFORE_L2_NORM_ENUMS.contains(similarityMeasure) == false) {
+            return null;
+        }
+
+        return similarityMeasure;
+    }
 }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsAction.java

@@ -31,7 +31,7 @@ public class CohereEmbeddingsAction implements ExecutableAction {
         Objects.requireNonNull(model);
         this.sender = Objects.requireNonNull(sender);
         this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
-            model.getServiceSettings().getCommonSettings().getUri(),
+            model.getServiceSettings().getCommonSettings().uri(),
             "Cohere embeddings"
         );
         requestCreator = new CohereEmbeddingsExecutableRequestCreator(model);

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsExecutableRequestCreator.java

@@ -37,7 +37,7 @@ public class CohereEmbeddingsExecutableRequestCreator implements ExecutableReque
 
     public CohereEmbeddingsExecutableRequestCreator(CohereEmbeddingsModel model) {
         this.model = Objects.requireNonNull(model);
-        account = new CohereAccount(this.model.getServiceSettings().getCommonSettings().getUri(), this.model.getSecretSettings().apiKey());
+        account = new CohereAccount(this.model.getServiceSettings().getCommonSettings().uri(), this.model.getSecretSettings().apiKey());
     }
 
     @Override

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequest.java

@@ -46,7 +46,7 @@ public class CohereEmbeddingsRequest implements Request {
         this.input = Objects.requireNonNull(input);
         uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri);
         taskSettings = embeddingsModel.getTaskSettings();
-        model = embeddingsModel.getServiceSettings().getCommonSettings().getModelId();
+        model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
         embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
         inferenceEntityId = embeddingsModel.getInferenceEntityId();
     }

+ 8 - 11
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -182,17 +182,14 @@ public class ServiceUtils {
     }
 
     public static SimilarityMeasure extractSimilarity(Map<String, Object> map, String scope, ValidationException validationException) {
-        String similarity = extractOptionalString(map, SIMILARITY, scope, validationException);
-
-        if (similarity != null) {
-            try {
-                return SimilarityMeasure.fromString(similarity);
-            } catch (IllegalArgumentException iae) {
-                validationException.addValidationError("[" + scope + "] Unknown similarity measure [" + similarity + "]");
-            }
-        }
-
-        return null;
+        return extractOptionalEnum(
+            map,
+            SIMILARITY,
+            scope,
+            SimilarityMeasure::fromString,
+            EnumSet.allOf(SimilarityMeasure.class),
+            validationException
+        );
     }
 
     public static String extractRequiredString(

+ 8 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

@@ -216,13 +216,16 @@ public class CohereService extends SenderService {
     }
 
     private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
+        var similarityFromModel = model.getServiceSettings().similarity();
+        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
         CohereEmbeddingsServiceSettings serviceSettings = new CohereEmbeddingsServiceSettings(
             new CohereServiceSettings(
-                model.getServiceSettings().getCommonSettings().getUri(),
-                SimilarityMeasure.DOT_PRODUCT,
+                model.getServiceSettings().getCommonSettings().uri(),
+                similarityToUse,
                 embeddingSize,
-                model.getServiceSettings().getCommonSettings().getMaxInputTokens(),
-                model.getServiceSettings().getCommonSettings().getModelId()
+                model.getServiceSettings().getCommonSettings().maxInputTokens(),
+                model.getServiceSettings().getCommonSettings().modelId()
             ),
             model.getServiceSettings().getEmbeddingType()
         );
@@ -232,6 +235,6 @@ public class CohereService extends SenderService {
 
     @Override
     public TransportVersion getMinimalSupportedVersion() {
-        return TransportVersions.ML_INFERENCE_EMBEDDING_BYTE_ADDED;
+        return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
     }
 }

+ 10 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java

@@ -65,10 +65,10 @@ public class CohereServiceSettings implements ServiceSettings {
             throw validationException;
         }
 
-        return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, getModelId(oldModelId, modelId));
+        return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId));
     }
 
-    private static String getModelId(@Nullable String model, @Nullable String modelId) {
+    private static String modelId(@Nullable String model, @Nullable String modelId) {
         return modelId != null ? modelId : model;
     }
 
@@ -110,23 +110,25 @@ public class CohereServiceSettings implements ServiceSettings {
         modelId = in.readOptionalString();
     }
 
-    public URI getUri() {
+    public URI uri() {
         return uri;
     }
 
-    public SimilarityMeasure getSimilarity() {
+    @Override
+    public SimilarityMeasure similarity() {
         return similarity;
     }
 
-    public Integer getDimensions() {
+    @Override
+    public Integer dimensions() {
         return dimensions;
     }
 
-    public Integer getMaxInputTokens() {
+    public Integer maxInputTokens() {
         return maxInputTokens;
     }
 
-    public String getModelId() {
+    public String modelId() {
         return modelId;
     }
 
@@ -179,7 +181,7 @@ public class CohereServiceSettings implements ServiceSettings {
     public void writeTo(StreamOutput out) throws IOException {
         var uriToWrite = uri != null ? uri.toString() : null;
         out.writeOptionalString(uriToWrite);
-        out.writeOptionalEnum(similarity);
+        out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
         out.writeOptionalVInt(dimensions);
         out.writeOptionalVInt(maxInputTokens);
         out.writeOptionalString(modelId);

+ 11 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -96,6 +97,16 @@ public class CohereEmbeddingsServiceSettings implements ServiceSettings {
         return commonSettings;
     }
 
+    @Override
+    public SimilarityMeasure similarity() {
+        return commonSettings.similarity();
+    }
+
+    @Override
+    public Integer dimensions() {
+        return commonSettings.dimensions();
+    }
+
     public CohereEmbeddingType getEmbeddingType() {
         return embeddingType;
     }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -406,7 +406,7 @@ public class ElasticsearchInternalService implements InferenceService {
 
     @Override
     public TransportVersion getMinimalSupportedVersion() {
-        return TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED;
+        return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
     }
 
     @Override

+ 28 - 12
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java

@@ -11,8 +11,9 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
-import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
 
@@ -26,6 +27,9 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
 
     public static final String NAME = "multilingual_e5_small_service_settings";
 
+    static final int DIMENSIONS = 384;
+    static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE;
+
     public MultilingualE5SmallInternalServiceSettings(int numAllocations, int numThreads, String modelId) {
         super(numAllocations, numThreads, modelId);
     }
@@ -45,6 +49,16 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
      */
     public static MultilingualE5SmallInternalServiceSettings.Builder fromMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
+        var requestFields = extractRequestFields(map, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return createBuilder(requestFields);
+    }
+
+    private static RequestFields extractRequestFields(Map<String, Object> map, ValidationException validationException) {
         Integer numAllocations = ServiceUtils.removeAsType(map, NUM_ALLOCATIONS, Integer.class);
         Integer numThreads = ServiceUtils.removeAsType(map, NUM_THREADS, Integer.class);
 
@@ -62,26 +76,23 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
             }
         }
 
-        if (validationException.validationErrors().isEmpty() == false) {
-            throw validationException;
-        }
+        return new RequestFields(numAllocations, numThreads, modelId);
+    }
 
+    private static MultilingualE5SmallInternalServiceSettings.Builder createBuilder(RequestFields requestFields) {
         var builder = new InternalServiceSettings.Builder() {
             @Override
             public MultilingualE5SmallInternalServiceSettings build() {
                 return new MultilingualE5SmallInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId());
             }
         };
-        builder.setNumAllocations(numAllocations);
-        builder.setNumThreads(numThreads);
-        builder.setModelId(modelId);
+        builder.setNumAllocations(requestFields.numAllocations);
+        builder.setNumThreads(requestFields.numThreads);
+        builder.setModelId(requestFields.modelId);
         return builder;
     }
 
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        return super.toXContent(builder, params);
-    }
+    private record RequestFields(@Nullable Integer numAllocations, @Nullable Integer numThreads, @Nullable String modelId) {}
 
     @Override
     public boolean isFragment() {
@@ -103,9 +114,14 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
         super.writeTo(out);
     }
 
+    @Override
+    public SimilarityMeasure similarity() {
+        return SIMILARITY;
+    }
+
     @Override
     public Integer dimensions() {
-        return 384;
+        return DIMENSIONS;
     }
 
     @Override

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

@@ -61,7 +61,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
     private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
         var serviceSettings = new HuggingFaceServiceSettings(
             model.getServiceSettings().uri(),
-            null, // Similarity measure is unknown
+            model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified
             embeddingSize,
             model.getTokenLimit()
         );
@@ -76,6 +76,6 @@ public class HuggingFaceService extends HuggingFaceBaseService {
 
     @Override
     public TransportVersion getMinimalSupportedVersion() {
-        return TransportVersions.V_8_12_0;
+        return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
     }
 }

+ 3 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettings.java

@@ -135,7 +135,7 @@ public class HuggingFaceServiceSettings implements ServiceSettings {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(uri.toString());
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            out.writeOptionalEnum(similarity);
+            out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
             out.writeOptionalVInt(dimensions);
             out.writeOptionalVInt(maxInputTokens);
         }
@@ -145,10 +145,12 @@ public class HuggingFaceServiceSettings implements ServiceSettings {
         return uri;
     }
 
+    @Override
     public SimilarityMeasure similarity() {
         return similarity;
     }
 
+    @Override
     public Integer dimensions() {
         return dimensions;
     }

+ 5 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

@@ -248,11 +248,14 @@ public class OpenAiService extends SenderService {
             );
         }
 
+        var similarityFromModel = model.getServiceSettings().similarity();
+        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
         OpenAiEmbeddingsServiceSettings serviceSettings = new OpenAiEmbeddingsServiceSettings(
             model.getServiceSettings().modelId(),
             model.getServiceSettings().uri(),
             model.getServiceSettings().organizationId(),
-            SimilarityMeasure.DOT_PRODUCT,
+            similarityToUse,
             embeddingSize,
             model.getServiceSettings().maxInputTokens(),
             model.getServiceSettings().dimensionsSetByUser()
@@ -263,7 +266,7 @@ public class OpenAiService extends SenderService {
 
     @Override
     public TransportVersion getMinimalSupportedVersion() {
-        return TransportVersions.ML_COMPLETION_INFERENCE_SERVICE_ADDED;
+        return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
     }
 
     /**

+ 4 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java

@@ -192,10 +192,12 @@ public class OpenAiEmbeddingsServiceSettings implements ServiceSettings {
         return organizationId;
     }
 
+    @Override
     public SimilarityMeasure similarity() {
         return similarity;
     }
 
+    @Override
     public Integer dimensions() {
         return dimensions;
     }
@@ -277,8 +279,9 @@ public class OpenAiEmbeddingsServiceSettings implements ServiceSettings {
         var uriToWrite = uri != null ? uri.toString() : null;
         out.writeOptionalString(uriToWrite);
         out.writeOptionalString(organizationId);
+
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            out.writeOptionalEnum(similarity);
+            out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
             out.writeOptionalVInt(dimensions);
             out.writeOptionalVInt(maxInputTokens);
         }

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java

@@ -178,7 +178,7 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
     }
 
     public static CohereEmbeddingsRequest createRequest(List<String> input, CohereEmbeddingsModel model) {
-        var account = new CohereAccount(model.getServiceSettings().getCommonSettings().getUri(), model.getSecretSettings().apiKey());
+        var account = new CohereAccount(model.getServiceSettings().getCommonSettings().uri(), model.getSecretSettings().apiKey());
         return new CohereEmbeddingsRequest(account, input, model);
     }
 }

+ 5 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java

@@ -151,7 +151,7 @@ public class CohereServiceSettingsTests extends AbstractWireSerializingTestCase<
 
     public void testFromMap_MissingUrl_DoesNotThrowException() {
         var serviceSettings = CohereServiceSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.PERSISTENT);
-        assertNull(serviceSettings.getUri());
+        assertNull(serviceSettings.uri());
     }
 
     public void testFromMap_EmptyUrl_ThrowsError() {
@@ -196,7 +196,10 @@ public class CohereServiceSettingsTests extends AbstractWireSerializingTestCase<
 
         MatcherAssert.assertThat(
             thrownException.getMessage(),
-            is("Validation Failed: 1: [service_settings] Unknown similarity measure [by_size];")
+            is(
+                "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] "
+                    + "must be one of [cosine, dot_product, l2_norm];"
+            )
         );
     }
 

+ 154 - 24
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockResponse;
@@ -99,8 +100,8 @@ public class CohereServiceTests extends ESTestCase {
                 MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
                 var embeddingsModel = (CohereEmbeddingsModel) model;
-                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
                 MatcherAssert.assertThat(
                     embeddingsModel.getTaskSettings(),
@@ -131,8 +132,8 @@ public class CohereServiceTests extends ESTestCase {
                 MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
                 var embeddingsModel = (CohereEmbeddingsModel) model;
-                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
                 MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
                 MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
@@ -257,7 +258,7 @@ public class CohereServiceTests extends ESTestCase {
                 MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
                 var embeddingsModel = (CohereEmbeddingsModel) model;
-                assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
+                assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri());
                 MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
                 MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
             }, (e) -> fail("Model parsing should have succeeded " + e.getMessage()));
@@ -295,8 +296,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
             MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
         }
@@ -345,7 +346,7 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
+            assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri());
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)));
             MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
         }
@@ -370,8 +371,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.BYTE));
             MatcherAssert.assertThat(
                 embeddingsModel.getTaskSettings(),
@@ -402,7 +403,7 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
             MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
         }
@@ -427,8 +428,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
             MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
         }
@@ -451,7 +452,7 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
             MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
         }
@@ -478,8 +479,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)));
             MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
         }
@@ -497,8 +498,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE)));
             assertNull(embeddingsModel.getSecretSettings());
         }
@@ -535,8 +536,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri());
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
             assertNull(embeddingsModel.getSecretSettings());
@@ -556,7 +557,7 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
             assertNull(embeddingsModel.getSecretSettings());
         }
@@ -574,7 +575,7 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)));
             assertNull(embeddingsModel.getSecretSettings());
         }
@@ -595,8 +596,8 @@ public class CohereServiceTests extends ESTestCase {
             MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
 
             var embeddingsModel = (CohereEmbeddingsModel) model;
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
-            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
+            MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
             MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)));
             assertNull(embeddingsModel.getSecretSettings());
         }
@@ -755,6 +756,135 @@ public class CohereServiceTests extends ESTestCase {
         }
     }
 
+    public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+                    "texts": [
+                        "hello"
+                    ],
+                    "embeddings": {
+                        "float": [
+                            [
+                                0.123,
+                                -0.123
+                            ]
+                        ]
+                    },
+                    "meta": {
+                        "api_version": {
+                            "version": "1"
+                        },
+                        "billed_units": {
+                            "input_tokens": 1
+                        }
+                    },
+                    "response_type": "embeddings_by_type"
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = CohereEmbeddingsModelTests.createModel(
+                getUrl(webServer),
+                "secret",
+                CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                10,
+                1,
+                null,
+                null
+            );
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+            var result = listener.actionGet(TIMEOUT);
+
+            MatcherAssert.assertThat(
+                result,
+                // the dimension is set to 2 because there are 2 embeddings returned from the mock server
+                is(
+                    CohereEmbeddingsModelTests.createModel(
+                        getUrl(webServer),
+                        "secret",
+                        CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        10,
+                        2,
+                        null,
+                        null,
+                        SimilarityMeasure.DOT_PRODUCT
+                    )
+                )
+            );
+        }
+    }
+
+    public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
+                    "texts": [
+                        "hello"
+                    ],
+                    "embeddings": {
+                        "float": [
+                            [
+                                0.123,
+                                -0.123
+                            ]
+                        ]
+                    },
+                    "meta": {
+                        "api_version": {
+                            "version": "1"
+                        },
+                        "billed_units": {
+                            "input_tokens": 1
+                        }
+                    },
+                    "response_type": "embeddings_by_type"
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = CohereEmbeddingsModelTests.createModel(
+                getUrl(webServer),
+                "secret",
+                CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                10,
+                1,
+                null,
+                null,
+                SimilarityMeasure.COSINE
+            );
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+            var result = listener.actionGet(TIMEOUT);
+
+            MatcherAssert.assertThat(
+                result,
+                // the dimension is set to 2 because there are 2 embeddings returned from the mock server
+                is(
+                    CohereEmbeddingsModelTests.createModel(
+                        getUrl(webServer),
+                        "secret",
+                        CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        10,
+                        2,
+                        null,
+                        null,
+                        SimilarityMeasure.COSINE
+                    )
+                )
+            );
+        }
+    }
+
     public void testInfer_UnauthorisedResponse() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 

+ 23 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java

@@ -225,4 +225,27 @@ public class CohereEmbeddingsModelTests extends ESTestCase {
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
     }
+
+    public static CohereEmbeddingsModel createModel(
+        String url,
+        String apiKey,
+        CohereEmbeddingsTaskSettings taskSettings,
+        @Nullable Integer tokenLimit,
+        @Nullable Integer dimensions,
+        @Nullable String model,
+        @Nullable CohereEmbeddingType embeddingType,
+        @Nullable SimilarityMeasure similarityMeasure
+    ) {
+        return new CohereEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new CohereEmbeddingsServiceSettings(
+                new CohereServiceSettings(url, similarityMeasure, dimensions, tokenLimit, model),
+                Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT)
+            ),
+            taskSettings,
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
 }

+ 13 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -29,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
 
 import java.util.ArrayList;
@@ -235,16 +237,17 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
             settings.put(
                 ModelConfigurations.SERVICE_SETTINGS,
                 new HashMap<>(
-                    Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4)
+                    Map.of(
+                        ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS,
+                        1,
+                        ElasticsearchInternalServiceSettings.NUM_THREADS,
+                        4,
+                        ServiceFields.SIMILARITY,
+                        SimilarityMeasure.L2_NORM.toString()
+                    )
                 )
             );
 
-            var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings(
-                1,
-                4,
-                ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID
-            );
-
             expectThrows(IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, taskType, settings));
 
         }
@@ -290,7 +293,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
                         ElasticsearchInternalServiceSettings.NUM_THREADS,
                         4,
                         InternalServiceSettings.MODEL_ID,
-                        ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID
+                        ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID,
+                        ServiceFields.DIMENSIONS,
+                        1
                     )
                 )
             );

+ 7 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java

@@ -117,7 +117,13 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest
             () -> HuggingFaceServiceSettings.fromMap(new HashMap<>(Map.of(ServiceFields.URL, url, ServiceFields.SIMILARITY, similarity)))
         );
 
-        assertThat(thrownException.getMessage(), is("Validation Failed: 1: [service_settings] Unknown similarity measure [by_size];"));
+        assertThat(
+            thrownException.getMessage(),
+            is(
+                "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] "
+                    + "must be one of [cosine, dot_product, l2_norm];"
+            )
+        );
     }
 
     @Override

+ 54 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockResponse;
@@ -512,6 +513,59 @@ public class HuggingFaceServiceTests extends ESTestCase {
         }
     }
 
+    public void testCheckModelConfig_UsesUserSpecifiedSimilarity() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "embeddings": [
+                        [
+                            -0.0123
+                        ]
+                    ]
+                {
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 2, SimilarityMeasure.COSINE);
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE))
+            );
+        }
+    }
+
+    public void testCheckModelConfig_LeavesSimilarityAsNull_WhenUnspecified() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "embeddings": [
+                        [
+                            -0.0123
+                        ]
+                    ]
+                {
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 2, null);
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, null)));
+        }
+    }
+
     private HuggingFaceService createHuggingFaceService() {
         return new HuggingFaceService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
     }

+ 18 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java

@@ -8,6 +8,8 @@
 package org.elasticsearch.xpack.inference.services.huggingface.embeddings;
 
 import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
@@ -52,4 +54,20 @@ public class HuggingFaceEmbeddingsModelTests extends ESTestCase {
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
     }
+
+    public static HuggingFaceEmbeddingsModel createModel(
+        String url,
+        String apiKey,
+        int tokenLimit,
+        int dimensions,
+        @Nullable SimilarityMeasure similarityMeasure
+    ) {
+        return new HuggingFaceEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new HuggingFaceServiceSettings(createUri(url), similarityMeasure, dimensions, tokenLimit),
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
 }

+ 113 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

@@ -952,6 +952,118 @@ public class OpenAiServiceTests extends ESTestCase {
     public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
+        try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 100, false);
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var returnedModel = listener.actionGet(TIMEOUT);
+            assertThat(
+                returnedModel,
+                is(
+                    OpenAiEmbeddingsModelTests.createModel(
+                        getUrl(webServer),
+                        "org",
+                        "secret",
+                        "model",
+                        "user",
+                        SimilarityMeasure.DOT_PRODUCT,
+                        100,
+                        2,
+                        false
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user")));
+        }
+    }
+
+    public void testCheckModelConfig_ReturnsNewModelReference_SetsSimilarityToDocProduct_WhenNull() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 100, false);
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var returnedModel = listener.actionGet(TIMEOUT);
+            assertThat(
+                returnedModel,
+                is(
+                    OpenAiEmbeddingsModelTests.createModel(
+                        getUrl(webServer),
+                        "org",
+                        "secret",
+                        "model",
+                        "user",
+                        SimilarityMeasure.DOT_PRODUCT,
+                        100,
+                        2,
+                        false
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user")));
+        }
+    }
+
+    public void testCheckModelConfig_ReturnsNewModelReference_DoesNotOverrideSimilarity_WhenNotNull() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
         try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
 
             String responseJson = """
@@ -1000,7 +1112,7 @@ public class OpenAiServiceTests extends ESTestCase {
                         "secret",
                         "model",
                         "user",
-                        SimilarityMeasure.DOT_PRODUCT,
+                        SimilarityMeasure.COSINE,
                         100,
                         2,
                         false

+ 7 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java

@@ -303,7 +303,13 @@ public class OpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             )
         );
 
-        assertThat(thrownException.getMessage(), is("Validation Failed: 1: [service_settings] Unknown similarity measure [by_size];"));
+        assertThat(
+            thrownException.getMessage(),
+            is(
+                "Validation Failed: 1: [service_settings] Invalid value [by_size] received. [similarity] "
+                    + "must be one of [cosine, dot_product, l2_norm];"
+            )
+        );
     }
 
     public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException {