浏览代码

Adding support for specifying embedding type to Jina AI service settings (#121548)

* Adding embeddings type to Jina AI service settings

* Update docs/changelog/121548.yaml

* Setting default similarity to L2 norm for binary embedding type
Ying Mao 8 月之前
父节点
当前提交
6b2e56697e
共有 14 个文件被更改,包括 911 次插入和 140 次删除
  1. 5 0
      docs/changelog/121548.yaml
  2. 2 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 8 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java
  4. 12 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java
  5. 88 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java
  6. 8 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java
  7. 119 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java
  8. 50 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java
  9. 47 4
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java
  10. 105 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java
  11. 157 24
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java
  12. 91 52
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java
  13. 138 24
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java
  14. 81 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java

+ 5 - 0
docs/changelog/121548.yaml

@@ -0,0 +1,5 @@
+pr: 121548
+summary: Adding support for specifying embedding type to Jina AI service settings
+area: Machine Learning
+type: enhancement
+issues: []

+ 2 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -181,6 +181,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_RETRY_ON_SHARD_LEVEL_FAILURE_BACKPORT_8_19 = def(8_841_0_03);
     public static final TransportVersion ESQL_SUPPORT_PARTIAL_RESULTS_BACKPORT_8_19 = def(8_841_0_04);
     public static final TransportVersion VOYAGE_AI_INTEGRATION_ADDED_BACKPORT_8_X = def(8_841_0_05);
+    public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
     public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
     public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -207,6 +208,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_DRIVER_NODE_DESCRIPTION = def(9_017_0_00);
     public static final TransportVersion MULTI_PROJECT = def(9_018_0_00);
     public static final TransportVersion STORED_SCRIPT_CONTENT_LENGTH = def(9_019_0_00);
+    public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_020_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 8 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequest.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
 
@@ -30,6 +31,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
     private final JinaAIEmbeddingsTaskSettings taskSettings;
     private final String model;
     private final String inferenceEntityId;
+    private final JinaAIEmbeddingType embeddingType;
 
     public JinaAIEmbeddingsRequest(List<String> input, JinaAIEmbeddingsModel embeddingsModel) {
         Objects.requireNonNull(embeddingsModel);
@@ -38,6 +40,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
         this.input = Objects.requireNonNull(input);
         taskSettings = embeddingsModel.getTaskSettings();
         model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
+        embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
         inferenceEntityId = embeddingsModel.getInferenceEntityId();
     }
 
@@ -46,7 +49,7 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
         HttpPost httpPost = new HttpPost(account.uri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
+            Strings.toString(new JinaAIEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);
 
@@ -75,6 +78,10 @@ public class JinaAIEmbeddingsRequest extends JinaAIRequest {
         return null;
     }
 
+    public JinaAIEmbeddingType getEmbeddingType() {
+        return embeddingType;
+    }
+
     public static URI buildDefaultUri() throws URISyntaxException {
         return new URIBuilder().setScheme("https")
             .setHost(JinaAIUtils.HOST)

+ 12 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntity.java

@@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
 
 import java.io.IOException;
@@ -19,9 +20,12 @@ import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage;
 
-public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddingsTaskSettings taskSettings, @Nullable String model)
-    implements
-        ToXContentObject {
+public record JinaAIEmbeddingsRequestEntity(
+    List<String> input,
+    JinaAIEmbeddingsTaskSettings taskSettings,
+    @Nullable String model,
+    @Nullable JinaAIEmbeddingType embeddingType
+) implements ToXContentObject {
 
     private static final String SEARCH_DOCUMENT = "retrieval.passage";
     private static final String SEARCH_QUERY = "retrieval.query";
@@ -30,6 +34,7 @@ public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddings
     private static final String INPUT_FIELD = "input";
     private static final String MODEL_FIELD = "model";
     public static final String TASK_TYPE_FIELD = "task";
+    static final String EMBEDDING_TYPE_FIELD = "embedding_type";
 
     public JinaAIEmbeddingsRequestEntity {
         Objects.requireNonNull(input);
@@ -43,6 +48,10 @@ public record JinaAIEmbeddingsRequestEntity(List<String> input, JinaAIEmbeddings
         builder.field(INPUT_FIELD, input);
         builder.field(MODEL_FIELD, model);
 
+        if (embeddingType != null) {
+            builder.field(EMBEDDING_TYPE_FIELD, embeddingType.toRequestString());
+        }
+
         if (taskSettings.getInputType() != null) {
             builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
         }

+ 88 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntity.java

@@ -9,28 +9,54 @@
 
 package org.elasticsearch.xpack.inference.external.response.jinaai;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequest;
 import org.elasticsearch.xpack.inference.external.response.XContentUtils;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType.toLowerCase;
 
 public class JinaAIEmbeddingsResponseEntity {
     private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in JinaAI embeddings response";
 
+    private static final Map<String, CheckedFunction<XContentParser, InferenceServiceResults, IOException>> EMBEDDING_PARSERS = Map.of(
+        toLowerCase(JinaAIEmbeddingType.FLOAT),
+        JinaAIEmbeddingsResponseEntity::parseFloatDataObject,
+        toLowerCase(JinaAIEmbeddingType.BIT),
+        JinaAIEmbeddingsResponseEntity::parseBitDataObject,
+        toLowerCase(JinaAIEmbeddingType.BINARY),
+        JinaAIEmbeddingsResponseEntity::parseBitDataObject
+    );
+    private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
+
+    private static String supportedEmbeddingTypes() {
+        var validTypes = EMBEDDING_PARSERS.keySet().toArray(String[]::new);
+        Arrays.sort(validTypes);
+        return String.join(", ", validTypes);
+    }
+
     /**
      * Parses the JinaAI json response.
      * For a request like:
@@ -73,8 +99,21 @@ public class JinaAIEmbeddingsResponseEntity {
      * </code>
      * </pre>
      */
-    public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
+    public static InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+        // the embedding type is not specified anywhere in the response, so grab it from the request
+        JinaAIEmbeddingsRequest embeddingsRequest = (JinaAIEmbeddingsRequest) request;
+        var embeddingType = embeddingsRequest.getEmbeddingType().toString();
         var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+        var embeddingValueParser = EMBEDDING_PARSERS.get(embeddingType);
+
+        if (embeddingValueParser == null) {
+            throw new IllegalStateException(
+                Strings.format(
+                    "Failed to find a supported embedding type for in the Jina AI embeddings response. Supported types are [%s]",
+                    VALID_EMBEDDING_TYPES_STRING
+                )
+            );
+        }
 
         try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
             moveToFirstToken(jsonParser);
@@ -84,26 +123,66 @@ public class JinaAIEmbeddingsResponseEntity {
 
             positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);
 
-            List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
-                jsonParser,
-                JinaAIEmbeddingsResponseEntity::parseEmbeddingObject
-            );
-
-            return new TextEmbeddingFloatResults(embeddingList);
+            return embeddingValueParser.apply(jsonParser);
         }
     }
 
-    private static TextEmbeddingFloatResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException {
+    private static InferenceServiceResults parseFloatDataObject(XContentParser jsonParser) throws IOException {
+        List<TextEmbeddingFloatResults.Embedding> embeddingList = parseList(
+            jsonParser,
+            JinaAIEmbeddingsResponseEntity::parseFloatEmbeddingObject
+        );
+
+        return new TextEmbeddingFloatResults(embeddingList);
+    }
+
+    private static TextEmbeddingFloatResults.Embedding parseFloatEmbeddingObject(XContentParser parser) throws IOException {
         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
 
         positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
 
-        List<Float> embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
+        var embeddingValuesList = parseList(parser, XContentUtils::parseFloat);
         // parse and discard the rest of the object
         consumeUntilObjectEnd(parser);
 
         return TextEmbeddingFloatResults.Embedding.of(embeddingValuesList);
     }
 
+    private static InferenceServiceResults parseBitDataObject(XContentParser jsonParser) throws IOException {
+        List<TextEmbeddingByteResults.Embedding> embeddingList = parseList(
+            jsonParser,
+            JinaAIEmbeddingsResponseEntity::parseBitEmbeddingObject
+        );
+
+        return new TextEmbeddingBitResults(embeddingList);
+    }
+
+    private static TextEmbeddingByteResults.Embedding parseBitEmbeddingObject(XContentParser parser) throws IOException {
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+
+        positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+        var embeddingList = parseList(parser, JinaAIEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
+        // parse and discard the rest of the object
+        consumeUntilObjectEnd(parser);
+
+        return TextEmbeddingByteResults.Embedding.of(embeddingList);
+    }
+
+    private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {
+        XContentParser.Token token = parser.currentToken();
+        ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
+        var parsedByte = parser.shortValue();
+        checkByteBounds(parsedByte);
+
+        return (byte) parsedByte;
+    }
+
+    private static void checkByteBounds(short value) {
+        if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
+            throw new IllegalArgumentException("Value [" + value + "] is out of range for a byte");
+        }
+    }
+
     private JinaAIEmbeddingsResponseEntity() {}
 }

+ 8 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java

@@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
@@ -294,7 +295,7 @@ public class JinaAIService extends SenderService {
         if (model instanceof JinaAIEmbeddingsModel embeddingsModel) {
             var serviceSettings = embeddingsModel.getServiceSettings();
             var similarityFromModel = serviceSettings.similarity();
-            var similarityToUse = similarityFromModel == null ? defaultSimilarity() : similarityFromModel;
+            var similarityToUse = similarityFromModel == null ? defaultSimilarity(serviceSettings.getEmbeddingType()) : similarityFromModel;
             var maxInputTokens = serviceSettings.maxInputTokens();
 
             var updatedServiceSettings = new JinaAIEmbeddingsServiceSettings(
@@ -305,7 +306,8 @@ public class JinaAIService extends SenderService {
                 ),
                 similarityToUse,
                 embeddingSize,
-                maxInputTokens
+                maxInputTokens,
+                serviceSettings.getEmbeddingType()
             );
 
             return new JinaAIEmbeddingsModel(embeddingsModel, updatedServiceSettings);
@@ -322,7 +324,10 @@ public class JinaAIService extends SenderService {
      *
      * @return The default similarity.
      */
-    static SimilarityMeasure defaultSimilarity() {
+    static SimilarityMeasure defaultSimilarity(JinaAIEmbeddingType embeddingType) {
+        if (embeddingType == JinaAIEmbeddingType.BINARY || embeddingType == JinaAIEmbeddingType.BIT) {
+            return SimilarityMeasure.L2_NORM;
+        }
         return SimilarityMeasure.DOT_PRODUCT;
     }
 

+ 119 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingType.java

@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.jinaai.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * Defines the type of embedding that the Jina AI API should return for a request.
+ *
+ */
+public enum JinaAIEmbeddingType {
+    /**
+     * Use this when you want to get back the default float embeddings.
+     */
+    FLOAT(DenseVectorFieldMapper.ElementType.FLOAT, RequestConstants.FLOAT),
+    /**
+     * Use this when you want to get back binary embeddings.
+     */
+    BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
+    /**
+     * This is a synonym for BIT
+     */
+    BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
+
+    private static final class RequestConstants {
+        private static final String FLOAT = "float";
+        private static final String BIT = "binary";
+    }
+
+    private static final Map<DenseVectorFieldMapper.ElementType, JinaAIEmbeddingType> ELEMENT_TYPE_TO_JINA_AI_EMBEDDING = Map.of(
+        DenseVectorFieldMapper.ElementType.FLOAT,
+        FLOAT,
+        DenseVectorFieldMapper.ElementType.BIT,
+        BIT
+    );
+    static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
+        ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.keySet()
+    );
+
+    private final DenseVectorFieldMapper.ElementType elementType;
+    private final String requestString;
+
+    JinaAIEmbeddingType(DenseVectorFieldMapper.ElementType elementType, String requestString) {
+        this.elementType = elementType;
+        this.requestString = requestString;
+    }
+
+    @Override
+    public String toString() {
+        return name().toLowerCase(Locale.ROOT);
+    }
+
+    public String toRequestString() {
+        return requestString;
+    }
+
+    public static String toLowerCase(JinaAIEmbeddingType type) {
+        return type.toString().toLowerCase(Locale.ROOT);
+    }
+
+    public static JinaAIEmbeddingType fromString(String name) {
+        return valueOf(name.trim().toUpperCase(Locale.ROOT));
+    }
+
+    public static JinaAIEmbeddingType fromElementType(DenseVectorFieldMapper.ElementType elementType) {
+        var embedding = ELEMENT_TYPE_TO_JINA_AI_EMBEDDING.get(elementType);
+
+        if (embedding == null) {
+            var validElementTypes = SUPPORTED_ELEMENT_TYPES.stream()
+                .map(value -> value.toString().toLowerCase(Locale.ROOT))
+                .toArray(String[]::new);
+            Arrays.sort(validElementTypes);
+
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Element type [%s] does not map to a Jina AI embedding value, must be one of [%s]",
+                    elementType,
+                    String.join(", ", validElementTypes)
+                )
+            );
+        }
+
+        return embedding;
+    }
+
+    public DenseVectorFieldMapper.ElementType toElementType() {
+        return elementType;
+    }
+
+    /**
+     * Returns an embedding type that is known to the transport version provided. If the embedding type enum had not yet
+     * been introduced in that version, the result defaults to FLOAT.
+     *
+     * @param embeddingType the value to translate if necessary
+     * @param version the version that dictates the translation
+     * @return the embedding type that is known to the version passed in
+     */
+    public static JinaAIEmbeddingType translateToVersion(JinaAIEmbeddingType embeddingType, TransportVersion version) {
+        if (version.onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)
+            || version.isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) {
+            return embeddingType;
+        }
+
+        return FLOAT;
+    }
+}

+ 50 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettings.java

@@ -23,18 +23,22 @@ import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 
 import java.io.IOException;
+import java.util.EnumSet;
 import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
 
 public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
     public static final String NAME = "jinaai_embeddings_service_settings";
 
+    static final String EMBEDDING_TYPE = "embedding_type";
+
     public static JinaAIEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();
         var commonServiceSettings = JinaAIServiceSettings.fromMap(map, context);
@@ -42,28 +46,47 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
         Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
         Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
 
+        JinaAIEmbeddingType embeddingTypes = parseEmbeddingType(map, validationException);
+
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
         }
 
-        return new JinaAIEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens);
+        return new JinaAIEmbeddingsServiceSettings(commonServiceSettings, similarity, dims, maxInputTokens, embeddingTypes);
+    }
+
+    static JinaAIEmbeddingType parseEmbeddingType(Map<String, Object> map, ValidationException validationException) {
+        return Objects.requireNonNullElse(
+            extractOptionalEnum(
+                map,
+                EMBEDDING_TYPE,
+                ModelConfigurations.SERVICE_SETTINGS,
+                JinaAIEmbeddingType::fromString,
+                EnumSet.allOf(JinaAIEmbeddingType.class),
+                validationException
+            ),
+            JinaAIEmbeddingType.FLOAT
+        );
     }
 
     private final JinaAIServiceSettings commonSettings;
     private final SimilarityMeasure similarity;
     private final Integer dimensions;
     private final Integer maxInputTokens;
+    private final JinaAIEmbeddingType embeddingType;
 
     public JinaAIEmbeddingsServiceSettings(
         JinaAIServiceSettings commonSettings,
         @Nullable SimilarityMeasure similarity,
         @Nullable Integer dimensions,
-        @Nullable Integer maxInputTokens
+        @Nullable Integer maxInputTokens,
+        @Nullable JinaAIEmbeddingType embeddingType
     ) {
         this.commonSettings = commonSettings;
         this.similarity = similarity;
         this.dimensions = dimensions;
         this.maxInputTokens = maxInputTokens;
+        this.embeddingType = embeddingType != null ? embeddingType : JinaAIEmbeddingType.FLOAT;
     }
 
     public JinaAIEmbeddingsServiceSettings(StreamInput in) throws IOException {
@@ -71,6 +94,11 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
         this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
         this.dimensions = in.readOptionalVInt();
         this.maxInputTokens = in.readOptionalVInt();
+
+        this.embeddingType = (in.getTransportVersion().onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)
+            || in.getTransportVersion().isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19))
+                ? Objects.requireNonNullElse(in.readOptionalEnum(JinaAIEmbeddingType.class), JinaAIEmbeddingType.FLOAT)
+                : JinaAIEmbeddingType.FLOAT;
     }
 
     public JinaAIServiceSettings getCommonSettings() {
@@ -96,9 +124,13 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
         return commonSettings.modelId();
     }
 
+    public JinaAIEmbeddingType getEmbeddingType() {
+        return embeddingType;
+    }
+
     @Override
     public DenseVectorFieldMapper.ElementType elementType() {
-        return DenseVectorFieldMapper.ElementType.FLOAT;
+        return embeddingType == null ? DenseVectorFieldMapper.ElementType.FLOAT : embeddingType.toElementType();
     }
 
     @Override
@@ -120,6 +152,10 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
         if (maxInputTokens != null) {
             builder.field(MAX_INPUT_TOKENS, maxInputTokens);
         }
+        if (embeddingType != null) {
+            builder.field(EMBEDDING_TYPE, embeddingType);
+        }
+
         builder.endObject();
         return builder;
     }
@@ -127,7 +163,9 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
     @Override
     protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
         commonSettings.toXContentFragmentOfExposedFields(builder, params);
-
+        if (embeddingType != null) {
+            builder.field(EMBEDDING_TYPE, embeddingType);
+        }
         return builder;
     }
 
@@ -142,6 +180,11 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
         out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
         out.writeOptionalVInt(dimensions);
         out.writeOptionalVInt(maxInputTokens);
+
+        if (out.getTransportVersion().onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)
+            || out.getTransportVersion().isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) {
+            out.writeOptionalEnum(JinaAIEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion()));
+        }
     }
 
     @Override
@@ -152,11 +195,12 @@ public class JinaAIEmbeddingsServiceSettings extends FilteredXContentObject impl
         return Objects.equals(commonSettings, that.commonSettings)
             && Objects.equals(similarity, that.similarity)
             && Objects.equals(dimensions, that.dimensions)
-            && Objects.equals(maxInputTokens, that.maxInputTokens);
+            && Objects.equals(maxInputTokens, that.maxInputTokens)
+            && Objects.equals(embeddingType, that.embeddingType);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens);
+        return Objects.hash(commonSettings, similarity, dimensions, maxInputTokens, embeddingType);
     }
 }

+ 47 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestEntityTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
 import org.hamcrest.MatcherAssert;
 
@@ -23,25 +24,67 @@ import static org.hamcrest.CoreMatchers.is;
 
 public class JinaAIEmbeddingsRequestEntityTests extends ESTestCase {
     public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
-        var entity = new JinaAIEmbeddingsRequestEntity(List.of("abc"), new JinaAIEmbeddingsTaskSettings(InputType.INGEST), "model");
+        var entity = new JinaAIEmbeddingsRequestEntity(
+            List.of("abc"),
+            new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
         String xContentResult = Strings.toString(builder);
 
         MatcherAssert.assertThat(xContentResult, is("""
-            {"input":["abc"],"model":"model","task":"retrieval.passage"}"""));
+            {"input":["abc"],"model":"model","embedding_type":"float","task":"retrieval.passage"}"""));
     }
 
     public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
-        var entity = new JinaAIEmbeddingsRequestEntity(List.of("abc"), JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, "model");
+        var entity = new JinaAIEmbeddingsRequestEntity(
+            List.of("abc"),
+            JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"input":["abc"],"model":"model","embedding_type":"float"}"""));
+    }
+
+    public void testXContent_EmbeddingTypesBit() throws IOException {
+        var entity = new JinaAIEmbeddingsRequestEntity(
+            List.of("abc"),
+            JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+            "model",
+            JinaAIEmbeddingType.BIT
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"input":["abc"],"model":"model","embedding_type":"binary"}"""));
+    }
+
+    public void testXContent_EmbeddingTypesBinary() throws IOException {
+        var entity = new JinaAIEmbeddingsRequestEntity(
+            List.of("abc"),
+            JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+            "model",
+            JinaAIEmbeddingType.BINARY
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
         String xContentResult = Strings.toString(builder);
 
         MatcherAssert.assertThat(xContentResult, is("""
-            {"input":["abc"],"model":"model"}"""));
+            {"input":["abc"],"model":"model","embedding_type":"binary"}"""));
     }
 
     public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {

+ 105 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIEmbeddingsRequestTests.java

@@ -12,6 +12,7 @@ import org.apache.http.client.methods.HttpPost;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
@@ -29,7 +30,15 @@ public class JinaAIEmbeddingsRequestTests extends ESTestCase {
     public void testCreateRequest_UrlDefined() throws IOException {
         var request = createRequest(
             List.of("abc"),
-            JinaAIEmbeddingsModelTests.createModel("url", "secret", JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, "model")
+            JinaAIEmbeddingsModelTests.createModel(
+                "url",
+                "secret",
+                JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                null,
+                null,
+                "model",
+                JinaAIEmbeddingType.FLOAT
+            )
         );
 
         var httpRequest = request.createHttpRequest();
@@ -46,13 +55,21 @@ public class JinaAIEmbeddingsRequestTests extends ESTestCase {
         );
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model")));
+        MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "embedding_type", "float")));
     }
 
     public void testCreateRequest_AllOptionsDefined() throws IOException {
         var request = createRequest(
             List.of("abc"),
-            JinaAIEmbeddingsModelTests.createModel("url", "secret", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model")
+            JinaAIEmbeddingsModelTests.createModel(
+                "url",
+                "secret",
+                new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+                null,
+                null,
+                "model",
+                JinaAIEmbeddingType.FLOAT
+            )
         );
 
         var httpRequest = request.createHttpRequest();
@@ -69,13 +86,58 @@ public class JinaAIEmbeddingsRequestTests extends ESTestCase {
         );
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.passage")));
+        MatcherAssert.assertThat(
+            requestMap,
+            is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.passage", "embedding_type", "float"))
+        );
     }
 
     public void testCreateRequest_InputTypeSearch() throws IOException {
         var request = createRequest(
             List.of("abc"),
-            JinaAIEmbeddingsModelTests.createModel("url", "secret", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model")
+            JinaAIEmbeddingsModelTests.createModel(
+                "url",
+                "secret",
+                new JinaAIEmbeddingsTaskSettings(InputType.SEARCH),
+                null,
+                null,
+                "model",
+                JinaAIEmbeddingType.FLOAT
+            )
+        );
+
+        var httpRequest = request.createHttpRequest();
+        MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)
+        );
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        MatcherAssert.assertThat(
+            requestMap,
+            is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query", "embedding_type", "float"))
+        );
+    }
+
+    public void testCreateRequest_EmbeddingTypeBit() throws IOException {
+        var request = createRequest(
+            List.of("abc"),
+            JinaAIEmbeddingsModelTests.createModel(
+                "url",
+                "secret",
+                new JinaAIEmbeddingsTaskSettings(InputType.SEARCH),
+                null,
+                null,
+                "model",
+                JinaAIEmbeddingType.BIT
+            )
         );
 
         var httpRequest = request.createHttpRequest();
@@ -92,7 +154,44 @@ public class JinaAIEmbeddingsRequestTests extends ESTestCase {
         );
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query")));
+        MatcherAssert.assertThat(
+            requestMap,
+            is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query", "embedding_type", "binary"))
+        );
+    }
+
+    public void testCreateRequest_EmbeddingTypeBinary() throws IOException {
+        var request = createRequest(
+            List.of("abc"),
+            JinaAIEmbeddingsModelTests.createModel(
+                "url",
+                "secret",
+                new JinaAIEmbeddingsTaskSettings(InputType.SEARCH),
+                null,
+                null,
+                "model",
+                JinaAIEmbeddingType.BINARY
+            )
+        );
+
+        var httpRequest = request.createHttpRequest();
+        MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(JinaAIUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(JinaAIUtils.ELASTIC_REQUEST_SOURCE)
+        );
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        MatcherAssert.assertThat(
+            requestMap,
+            is(Map.of("input", List.of("abc"), "model", "model", "task", "retrieval.query", "embedding_type", "binary"))
+        );
     }
 
     public static JinaAIEmbeddingsRequest createRequest(List<String> input, JinaAIEmbeddingsModel model) {

+ 157 - 24
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/jinaai/JinaAIEmbeddingsResponseEntityTests.java

@@ -9,15 +9,22 @@ package org.elasticsearch.xpack.inference.external.response.jinaai;
 
 import org.apache.http.HttpResponse;
 import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
-import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequestTests;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.List;
 
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
 
@@ -44,13 +51,25 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
             }
             """;
 
-        TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
-            mock(Request.class),
+        InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
+            JinaAIEmbeddingsRequestTests.createRequest(
+                List.of("abc"),
+                JinaAIEmbeddingsModelTests.createModel(
+                    "url",
+                    "secret",
+                    JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                    null,
+                    null,
+                    "model",
+                    JinaAIEmbeddingType.FLOAT
+                )
+            ),
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
+        assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
         assertThat(
-            parsedResults.embeddings(),
+            ((TextEmbeddingFloatResults) parsedResults).embeddings(),
             is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
         );
     }
@@ -85,13 +104,25 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
             }
             """;
 
-        TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
-            mock(Request.class),
+        InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
+            JinaAIEmbeddingsRequestTests.createRequest(
+                List.of("abc"),
+                JinaAIEmbeddingsModelTests.createModel(
+                    "url",
+                    "secret",
+                    JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                    null,
+                    null,
+                    "model",
+                    JinaAIEmbeddingType.FLOAT
+                )
+            ),
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
+        assertThat(parsedResults, instanceOf(TextEmbeddingFloatResults.class));
         assertThat(
-            parsedResults.embeddings(),
+            ((TextEmbeddingFloatResults) parsedResults).embeddings(),
             is(
                 List.of(
                     new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }),
@@ -126,7 +157,18 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
         var thrownException = expectThrows(
             IllegalStateException.class,
             () -> JinaAIEmbeddingsResponseEntity.fromResponse(
-                mock(Request.class),
+                JinaAIEmbeddingsRequestTests.createRequest(
+                    List.of("abc"),
+                    JinaAIEmbeddingsModelTests.createModel(
+                        "url",
+                        "secret",
+                        JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        null,
+                        null,
+                        "model",
+                        JinaAIEmbeddingType.FLOAT
+                    )
+                ),
                 new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
@@ -159,7 +201,18 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
         var thrownException = expectThrows(
             ParsingException.class,
             () -> JinaAIEmbeddingsResponseEntity.fromResponse(
-                mock(Request.class),
+                JinaAIEmbeddingsRequestTests.createRequest(
+                    List.of("abc"),
+                    JinaAIEmbeddingsModelTests.createModel(
+                        "url",
+                        "secret",
+                        JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        null,
+                        null,
+                        "model",
+                        JinaAIEmbeddingType.FLOAT
+                    )
+                ),
                 new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
@@ -195,7 +248,18 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
         var thrownException = expectThrows(
             IllegalStateException.class,
             () -> JinaAIEmbeddingsResponseEntity.fromResponse(
-                mock(Request.class),
+                JinaAIEmbeddingsRequestTests.createRequest(
+                    List.of("abc"),
+                    JinaAIEmbeddingsModelTests.createModel(
+                        "url",
+                        "secret",
+                        JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        null,
+                        null,
+                        "model",
+                        JinaAIEmbeddingType.FLOAT
+                    )
+                ),
                 new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
@@ -227,7 +291,18 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
         var thrownException = expectThrows(
             ParsingException.class,
             () -> JinaAIEmbeddingsResponseEntity.fromResponse(
-                mock(Request.class),
+                JinaAIEmbeddingsRequestTests.createRequest(
+                    List.of("abc"),
+                    JinaAIEmbeddingsModelTests.createModel(
+                        "url",
+                        "secret",
+                        JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        null,
+                        null,
+                        "model",
+                        JinaAIEmbeddingType.FLOAT
+                    )
+                ),
                 new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
@@ -238,7 +313,7 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
         );
     }
 
-    public void testFromResponse_SucceedsWhenEmbeddingValueIsInt() throws IOException {
+    public void testFromResponse_SucceedsWhenEmbeddingType_IsBinary() throws IOException {
         String responseJson = """
             {
               "object": "list",
@@ -247,7 +322,11 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
                       "object": "embedding",
                       "index": 0,
                       "embedding": [
-                          1
+                           -55,
+                            74,
+                            101,
+                            67,
+                            83
                       ]
                   }
               ],
@@ -259,15 +338,29 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
             }
             """;
 
-        TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
-            mock(Request.class),
+        InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
+            JinaAIEmbeddingsRequestTests.createRequest(
+                List.of("abc"),
+                JinaAIEmbeddingsModelTests.createModel(
+                    "url",
+                    "secret",
+                    JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                    null,
+                    null,
+                    "model",
+                    JinaAIEmbeddingType.BINARY
+                )
+            ),
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
-        assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 1.0F }))));
+        assertThat(
+            ((TextEmbeddingBitResults) parsedResults).embeddings(),
+            is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+        );
     }
 
-    public void testFromResponse_SucceedsWhenEmbeddingValueIsLong() throws IOException {
+    public void testFromResponse_SucceedsWhenEmbeddingType_IsBit() throws IOException {
         String responseJson = """
             {
               "object": "list",
@@ -276,7 +369,11 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
                       "object": "embedding",
                       "index": 0,
                       "embedding": [
-                          40294967295
+                           -55,
+                            74,
+                            101,
+                            67,
+                            83
                       ]
                   }
               ],
@@ -288,12 +385,26 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
             }
             """;
 
-        TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
-            mock(Request.class),
+        InferenceServiceResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
+            JinaAIEmbeddingsRequestTests.createRequest(
+                List.of("abc"),
+                JinaAIEmbeddingsModelTests.createModel(
+                    "url",
+                    "secret",
+                    JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                    null,
+                    null,
+                    "model",
+                    JinaAIEmbeddingType.BIT
+                )
+            ),
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
-        assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 4.0294965E10F }))));
+        assertThat(
+            ((TextEmbeddingBitResults) parsedResults).embeddings(),
+            is(List.of(new TextEmbeddingByteResults.Embedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
+        );
     }
 
     public void testFromResponse_FailsWhenEmbeddingValueIsAnObject() {
@@ -320,7 +431,18 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
         var thrownException = expectThrows(
             ParsingException.class,
             () -> JinaAIEmbeddingsResponseEntity.fromResponse(
-                mock(Request.class),
+                JinaAIEmbeddingsRequestTests.createRequest(
+                    List.of("abc"),
+                    JinaAIEmbeddingsModelTests.createModel(
+                        "url",
+                        "secret",
+                        JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                        null,
+                        null,
+                        "model",
+                        JinaAIEmbeddingType.BINARY
+                    )
+                ),
                 new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
@@ -372,8 +494,19 @@ public class JinaAIEmbeddingsResponseEntityTests extends ESTestCase {
                 }
             }""";
 
-        TextEmbeddingFloatResults parsedResults = JinaAIEmbeddingsResponseEntity.fromResponse(
-            mock(Request.class),
+        TextEmbeddingFloatResults parsedResults = (TextEmbeddingFloatResults) JinaAIEmbeddingsResponseEntity.fromResponse(
+            JinaAIEmbeddingsRequestTests.createRequest(
+                List.of("abc"),
+                JinaAIEmbeddingsModelTests.createModel(
+                    "url",
+                    "secret",
+                    JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
+                    null,
+                    null,
+                    "model",
+                    JinaAIEmbeddingType.FLOAT
+                )
+            ),
             new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
         );
 

+ 91 - 52
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java

@@ -42,6 +42,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettingsTests;
@@ -112,6 +113,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 var embeddingsModel = (JinaAIEmbeddingsModel) model;
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT));
                 MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST)));
                 MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
             }, e -> fail("Model parsing should have succeeded " + e.getMessage()));
@@ -120,7 +122,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 "id",
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
-                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                     JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST),
                     getSecretSettingsMap("secret")
                 ),
@@ -138,6 +140,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 var embeddingsModel = (JinaAIEmbeddingsModel) model;
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.FLOAT));
                 MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST)));
                 MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
                 assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
@@ -148,7 +151,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 "id",
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
-                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                     JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST),
                     createRandomChunkingSettingsMap(),
                     getSecretSettingsMap("secret")
@@ -167,6 +170,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 var embeddingsModel = (JinaAIEmbeddingsModel) model;
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
                 MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
+                MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(JinaAIEmbeddingType.BIT));
                 MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new JinaAIEmbeddingsTaskSettings(InputType.INGEST)));
                 MatcherAssert.assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
                 assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
@@ -177,7 +181,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 "id",
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
-                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.BIT),
                     JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST),
                     getSecretSettingsMap("secret")
                 ),
@@ -204,7 +208,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 "id",
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
-                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                     getSecretSettingsMap("secret")
                 ),
                 modelListener
@@ -224,7 +228,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 "id",
                 TaskType.SPARSE_EMBEDDING,
                 getRequestConfigMap(
-                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                     JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(),
                     getSecretSettingsMap("secret")
                 ),
@@ -243,7 +247,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
         try (var service = createJinaAIService()) {
             var config = getRequestConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(),
                 getSecretSettingsMap("secret")
             );
@@ -259,7 +263,7 @@ public class JinaAIServiceTests extends ESTestCase {
 
     public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
         try (var service = createJinaAIService()) {
-            var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model");
+            var serviceSettings = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT);
             serviceSettings.put("extra_key", "value");
 
             var config = getRequestConfigMap(
@@ -282,7 +286,7 @@ public class JinaAIServiceTests extends ESTestCase {
             taskSettingsMap.put("extra_key", "value");
 
             var config = getRequestConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 taskSettingsMap,
                 getSecretSettingsMap("secret")
             );
@@ -302,7 +306,7 @@ public class JinaAIServiceTests extends ESTestCase {
             secretSettingsMap.put("extra_key", "value");
 
             var config = getRequestConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(),
                 secretSettingsMap
             );
@@ -330,7 +334,7 @@ public class JinaAIServiceTests extends ESTestCase {
                 "id",
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
-                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"),
+                    JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", JinaAIEmbeddingType.FLOAT),
                     JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(),
                     getSecretSettingsMap("secret")
                 ),
@@ -343,7 +347,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModel() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
                 getSecretSettingsMap("secret")
             );
@@ -368,7 +372,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
                 createRandomChunkingSettingsMap(),
                 getSecretSettingsMap("secret")
@@ -395,7 +399,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
                 getSecretSettingsMap("secret")
             );
@@ -421,7 +425,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "oldmodel", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(),
                 getSecretSettingsMap("secret")
             );
@@ -446,7 +450,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.INGEST),
                 getSecretSettingsMap("secret")
             );
@@ -471,7 +475,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH),
                 getSecretSettingsMap("secret")
             );
@@ -500,7 +504,7 @@ public class JinaAIServiceTests extends ESTestCase {
             secretSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty(),
                 secretSettingsMap
             );
@@ -525,7 +529,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
                 getSecretSettingsMap("secret")
             );
@@ -550,7 +554,7 @@ public class JinaAIServiceTests extends ESTestCase {
 
     public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
         try (var service = createJinaAIService()) {
-            var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model");
+            var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT);
             serviceSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
@@ -582,7 +586,7 @@ public class JinaAIServiceTests extends ESTestCase {
             taskSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 taskSettingsMap,
                 getSecretSettingsMap("secret")
             );
@@ -607,7 +611,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModel() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null)
             );
 
@@ -626,7 +630,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null),
                 createRandomChunkingSettingsMap()
             );
@@ -647,7 +651,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null)
             );
 
@@ -667,7 +671,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model_old", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty()
             );
 
@@ -686,7 +690,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfig_CreatesAJinaAIEmbeddingsModelWithoutUrl() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMap(null)
             );
 
@@ -705,7 +709,7 @@ public class JinaAIServiceTests extends ESTestCase {
     public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
         try (var service = createJinaAIService()) {
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 JinaAIEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty()
             );
             persistedConfig.config().put("extra_key", "value");
@@ -724,7 +728,7 @@ public class JinaAIServiceTests extends ESTestCase {
 
     public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
         try (var service = createJinaAIService()) {
-            var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model");
+            var serviceSettingsMap = JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT);
             serviceSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
@@ -750,7 +754,7 @@ public class JinaAIServiceTests extends ESTestCase {
             taskSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
-                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model"),
+                JinaAIEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", "model", JinaAIEmbeddingType.FLOAT),
                 taskSettingsMap
             );
 
@@ -834,7 +838,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
                 10,
                 1,
-                "jina-clip-v2"
+                "jina-clip-v2",
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<Model> listener = new PlainActionFuture<>();
             service.checkModelConfig(model, listener);
@@ -850,7 +855,8 @@ public class JinaAIServiceTests extends ESTestCase {
                         JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS,
                         10,
                         2,
-                        "jina-clip-v2"
+                        "jina-clip-v2",
+                        JinaAIEmbeddingType.FLOAT
                     )
                 )
             );
@@ -891,7 +897,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 10,
                 1,
                 "jina-clip-v2",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<Model> listener = new PlainActionFuture<>();
             service.checkModelConfig(model, listener);
@@ -908,7 +915,8 @@ public class JinaAIServiceTests extends ESTestCase {
                         10,
                         2,
                         "jina-clip-v2",
-                        SimilarityMeasure.DOT_PRODUCT
+                        SimilarityMeasure.DOT_PRODUCT,
+                        JinaAIEmbeddingType.FLOAT
                     )
                 )
             );
@@ -949,7 +957,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 10,
                 1,
                 "jina-clip-v2",
-                SimilarityMeasure.COSINE
+                SimilarityMeasure.COSINE,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<Model> listener = new PlainActionFuture<>();
             service.checkModelConfig(model, listener);
@@ -966,7 +975,8 @@ public class JinaAIServiceTests extends ESTestCase {
                         10,
                         2,
                         "jina-clip-v2",
-                        SimilarityMeasure.COSINE
+                        SimilarityMeasure.COSINE,
+                        JinaAIEmbeddingType.FLOAT
                     )
                 )
             );
@@ -986,6 +996,7 @@ public class JinaAIServiceTests extends ESTestCase {
 
         try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) {
             var embeddingSize = randomNonNegativeInt();
+            var embeddingType = randomFrom(JinaAIEmbeddingType.values());
             var model = JinaAIEmbeddingsModelTests.createModel(
                 randomAlphaOfLength(10),
                 randomAlphaOfLength(10),
@@ -993,12 +1004,15 @@ public class JinaAIServiceTests extends ESTestCase {
                 randomNonNegativeInt(),
                 randomNonNegativeInt(),
                 randomAlphaOfLength(10),
-                similarityMeasure
+                similarityMeasure,
+                embeddingType
             );
 
             Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
 
-            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? JinaAIService.defaultSimilarity() : similarityMeasure;
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null
+                ? JinaAIService.defaultSimilarity(embeddingType)
+                : similarityMeasure;
             assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
             assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
         }
@@ -1023,7 +1037,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 1024,
                 1024,
                 "model",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             service.infer(
@@ -1110,7 +1125,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 1024,
                 1024,
                 "jina-clip-v2",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             service.infer(
@@ -1137,7 +1153,10 @@ public class JinaAIServiceTests extends ESTestCase {
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.passage")));
+            MatcherAssert.assertThat(
+                requestMap,
+                is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.passage", "embedding_type", "float"))
+            );
         }
     }
 
@@ -1175,7 +1194,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 1024,
                 1024,
                 "jina-clip-v2",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             service.infer(
@@ -1202,7 +1222,10 @@ public class JinaAIServiceTests extends ESTestCase {
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.query")));
+            MatcherAssert.assertThat(
+                requestMap,
+                is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "retrieval.query", "embedding_type", "float"))
+            );
         }
     }
 
@@ -1224,7 +1247,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 1024,
                 1024,
                 "jina-clip-v2",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             service.infer(
@@ -1251,7 +1275,10 @@ public class JinaAIServiceTests extends ESTestCase {
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "separation")));
+            MatcherAssert.assertThat(
+                requestMap,
+                is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "task", "separation", "embedding_type", "float"))
+            );
         }
     }
 
@@ -1289,7 +1316,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 1024,
                 1024,
                 "jina-clip-v2",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
@@ -1307,7 +1335,7 @@ public class JinaAIServiceTests extends ESTestCase {
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2")));
+            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "embedding_type", "float")));
         }
     }
 
@@ -1689,7 +1717,8 @@ public class JinaAIServiceTests extends ESTestCase {
                 1024,
                 1024,
                 "jina-clip-v2",
-                null
+                null,
+                JinaAIEmbeddingType.FLOAT
             );
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
             service.infer(
@@ -1715,7 +1744,7 @@ public class JinaAIServiceTests extends ESTestCase {
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2")));
+            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("abc"), "model", "jina-clip-v2", "embedding_type", "float")));
         }
     }
 
@@ -1727,7 +1756,8 @@ public class JinaAIServiceTests extends ESTestCase {
             createRandomChunkingSettings(),
             1024,
             1024,
-            "jina-clip-v2"
+            "jina-clip-v2",
+            JinaAIEmbeddingType.FLOAT
         );
 
         test_Embedding_ChunkedInfer_BatchesCalls(model);
@@ -1741,7 +1771,8 @@ public class JinaAIServiceTests extends ESTestCase {
             null,
             1024,
             1024,
-            "jina-clip-v2"
+            "jina-clip-v2",
+            JinaAIEmbeddingType.FLOAT
         );
 
         test_Embedding_ChunkedInfer_BatchesCalls(model);
@@ -1831,12 +1862,20 @@ public class JinaAIServiceTests extends ESTestCase {
             MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2")));
+            MatcherAssert.assertThat(
+                requestMap,
+                is(Map.of("input", List.of("foo", "bar"), "model", "jina-clip-v2", "embedding_type", "float"))
+            );
         }
     }
 
-    public void testDefaultSimilarity() {
-        assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity());
+    public void testDefaultSimilarity_BinaryEmbedding() {
+        assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BINARY));
+        assertEquals(SimilarityMeasure.L2_NORM, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.BIT));
+    }
+
+    public void testDefaultSimilarity_NotBinaryEmbedding() {
+        assertEquals(SimilarityMeasure.DOT_PRODUCT, JinaAIService.defaultSimilarity(JinaAIEmbeddingType.FLOAT));
     }
 
     @SuppressWarnings("checkstyle:LineLength")

+ 138 - 24
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsModelTests.java

@@ -25,69 +25,171 @@ import static org.hamcrest.Matchers.is;
 public class JinaAIEmbeddingsModelTests extends ESTestCase {
 
     public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() {
-        var model = createModel("url", "api_key", null, null, "model");
+        var model = createModel("url", "api_key", null, null, "model", JinaAIEmbeddingType.FLOAT);
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED);
         MatcherAssert.assertThat(overriddenModel, is(model));
     }
 
     public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() {
-        var model = createModel("url", "api_key", null, null, "model");
+        var model = createModel("url", "api_key", null, null, "model", JinaAIEmbeddingType.FLOAT);
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, null, InputType.UNSPECIFIED);
         MatcherAssert.assertThat(overriddenModel, is(model));
     }
 
     public void testOverrideWith_SetsInputTypeToIngest_WhenTheFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() {
-        var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model");
+        var model = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings((InputType) null),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.INGEST);
-        var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model");
+        var expectedModel = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
     public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() {
-        var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model");
+        var model = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.SEARCH);
-        var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model");
+        var expectedModel = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.SEARCH),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
     public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() {
-        var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model");
+        var model = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings((InputType) null),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.INGEST), InputType.SEARCH);
-        var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model");
+        var expectedModel = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.SEARCH),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
     public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() {
-        var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model");
+        var model = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(InputType.SEARCH), InputType.UNSPECIFIED);
-        var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.SEARCH), null, null, "model");
+        var expectedModel = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.SEARCH),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
     public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() {
-        var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model");
+        var model = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings((InputType) null),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED);
-        var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings((InputType) null), null, null, "model");
+        var expectedModel = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings((InputType) null),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
     public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() {
-        var model = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model");
+        var model = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
 
         var overriddenModel = JinaAIEmbeddingsModel.of(model, getTaskSettingsMap(null), InputType.UNSPECIFIED);
-        var expectedModel = createModel("url", "api_key", new JinaAIEmbeddingsTaskSettings(InputType.INGEST), null, null, "model");
+        var expectedModel = createModel(
+            "url",
+            "api_key",
+            new JinaAIEmbeddingsTaskSettings(InputType.INGEST),
+            null,
+            null,
+            "model",
+            JinaAIEmbeddingType.FLOAT
+        );
         MatcherAssert.assertThat(overriddenModel, is(expectedModel));
     }
 
-    public static JinaAIEmbeddingsModel createModel(String url, String apiKey, @Nullable Integer tokenLimit, @Nullable String model) {
-        return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model);
+    public static JinaAIEmbeddingsModel createModel(
+        String url,
+        String apiKey,
+        @Nullable Integer tokenLimit,
+        @Nullable String model,
+        @Nullable JinaAIEmbeddingType embeddingType
+    ) {
+        return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, null, model, embeddingType);
     }
 
     public static JinaAIEmbeddingsModel createModel(
@@ -95,9 +197,10 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase {
         String apiKey,
         @Nullable Integer tokenLimit,
         @Nullable Integer dimensions,
-        String model
+        String model,
+        @Nullable JinaAIEmbeddingType embeddingType
     ) {
-        return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model);
+        return createModel(url, apiKey, JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, tokenLimit, dimensions, model, embeddingType);
     }
 
     public static JinaAIEmbeddingsModel createModel(
@@ -107,7 +210,8 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase {
         ChunkingSettings chunkingSettings,
         @Nullable Integer tokenLimit,
         @Nullable Integer dimensions,
-        String model
+        String model,
+        @Nullable JinaAIEmbeddingType embeddingType
     ) {
         return new JinaAIEmbeddingsModel(
             "id",
@@ -116,7 +220,8 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase {
                 new JinaAIServiceSettings(url, model, null),
                 SimilarityMeasure.DOT_PRODUCT,
                 dimensions,
-                tokenLimit
+                tokenLimit,
+                embeddingType
             ),
             taskSettings,
             chunkingSettings,
@@ -130,7 +235,8 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase {
         JinaAIEmbeddingsTaskSettings taskSettings,
         @Nullable Integer tokenLimit,
         @Nullable Integer dimensions,
-        String model
+        String model,
+        @Nullable JinaAIEmbeddingType embeddingType
     ) {
         return new JinaAIEmbeddingsModel(
             "id",
@@ -139,7 +245,8 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase {
                 new JinaAIServiceSettings(url, model, null),
                 SimilarityMeasure.DOT_PRODUCT,
                 dimensions,
-                tokenLimit
+                tokenLimit,
+                embeddingType
             ),
             taskSettings,
             null,
@@ -154,12 +261,19 @@ public class JinaAIEmbeddingsModelTests extends ESTestCase {
         @Nullable Integer tokenLimit,
         @Nullable Integer dimensions,
         String model,
-        @Nullable SimilarityMeasure similarityMeasure
+        @Nullable SimilarityMeasure similarityMeasure,
+        @Nullable JinaAIEmbeddingType embeddingType
     ) {
         return new JinaAIEmbeddingsModel(
             "id",
             "service",
-            new JinaAIEmbeddingsServiceSettings(new JinaAIServiceSettings(url, model, null), similarityMeasure, dimensions, tokenLimit),
+            new JinaAIEmbeddingsServiceSettings(
+                new JinaAIServiceSettings(url, model, null),
+                similarityMeasure,
+                dimensions,
+                tokenLimit,
+                embeddingType
+            ),
             taskSettings,
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))

+ 81 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/embeddings/JinaAIEmbeddingsServiceSettingsTests.java

@@ -7,16 +7,18 @@
 
 package org.elasticsearch.xpack.inference.services.jinaai.embeddings;
 
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.SimilarityMeasure;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
@@ -35,7 +37,7 @@ import java.util.Map;
 
 import static org.hamcrest.Matchers.is;
 
-public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<JinaAIEmbeddingsServiceSettings> {
+public class JinaAIEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase<JinaAIEmbeddingsServiceSettings> {
     public static JinaAIEmbeddingsServiceSettings createRandom() {
         SimilarityMeasure similarityMeasure = null;
         Integer dims = null;
@@ -44,8 +46,9 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
 
         var commonSettings = JinaAIServiceSettingsTests.createRandom();
+        var embeddingType = randomFrom(JinaAIEmbeddingType.values());
 
-        return new JinaAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens);
+        return new JinaAIEmbeddingsServiceSettings(commonSettings, similarityMeasure, dims, maxInputTokens, embeddingType);
     }
 
     public void testFromMap() {
@@ -79,7 +82,8 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
                     new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null),
                     SimilarityMeasure.DOT_PRODUCT,
                     dims,
-                    maxInputTokens
+                    maxInputTokens,
+                    JinaAIEmbeddingType.FLOAT
                 )
             )
         );
@@ -116,7 +120,48 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
                     new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null),
                     SimilarityMeasure.DOT_PRODUCT,
                     dims,
-                    maxInputTokens
+                    maxInputTokens,
+                    JinaAIEmbeddingType.FLOAT
+                )
+            )
+        );
+    }
+
+    public void testFromMap_WithEmbeddingType() {
+        var url = "https://www.abc.com";
+        var similarity = SimilarityMeasure.DOT_PRODUCT.toString();
+        var dims = 1536;
+        var maxInputTokens = 512;
+        var model = "model";
+        var serviceSettings = JinaAIEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    ServiceFields.URL,
+                    url,
+                    ServiceFields.SIMILARITY,
+                    similarity,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    JinaAIServiceSettings.MODEL_ID,
+                    model,
+                    JinaAIEmbeddingsServiceSettings.EMBEDDING_TYPE,
+                    JinaAIEmbeddingType.BIT.toString()
+                )
+            ),
+            ConfigurationParseContext.REQUEST
+        );
+
+        MatcherAssert.assertThat(
+            serviceSettings,
+            is(
+                new JinaAIEmbeddingsServiceSettings(
+                    new JinaAIServiceSettings(ServiceUtils.createUri(url), model, null),
+                    SimilarityMeasure.DOT_PRODUCT,
+                    dims,
+                    maxInputTokens,
+                    JinaAIEmbeddingType.BIT
                 )
             )
         );
@@ -146,7 +191,8 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             new JinaAIServiceSettings("url", "model", new RateLimitSettings(3)),
             SimilarityMeasure.COSINE,
             5,
-            10
+            10,
+            JinaAIEmbeddingType.FLOAT
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -154,7 +200,8 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         String xContentResult = Strings.toString(builder);
         assertThat(xContentResult, is("""
             {"url":"url","model_id":"model",""" + """
-            "rate_limit":{"requests_per_minute":3},"similarity":"cosine","dimensions":5,"max_input_tokens":10}"""));
+            "rate_limit":{"requests_per_minute":3},""" + """
+            "similarity":"cosine","dimensions":5,"max_input_tokens":10,"embedding_type":"float"}"""));
     }
 
     @Override
@@ -172,6 +219,23 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         return randomValueOtherThan(instance, JinaAIEmbeddingsServiceSettingsTests::createRandom);
     }
 
+    @Override
+    protected JinaAIEmbeddingsServiceSettings mutateInstanceForVersion(JinaAIEmbeddingsServiceSettings instance, TransportVersion version) {
+        if (version.onOrAfter(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED)
+            || version.isPatchFrom(TransportVersions.JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19)) {
+            return instance;
+        }
+
+        // default to null embedding type if node is on a version before embedding type was introduced
+        return new JinaAIEmbeddingsServiceSettings(
+            instance.getCommonSettings(),
+            instance.similarity(),
+            instance.dimensions(),
+            instance.maxInputTokens(),
+            null
+        );
+    }
+
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
         List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
@@ -180,8 +244,17 @@ public class JinaAIEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         return new NamedWriteableRegistry(entries);
     }
 
-    public static Map<String, Object> getServiceSettingsMap(@Nullable String url, String model) {
+    public static Map<String, Object> getServiceSettingsMap(
+        @Nullable String url,
+        String model,
+        @Nullable JinaAIEmbeddingType embeddingType
+    ) {
         var map = new HashMap<>(JinaAIServiceSettingsTests.getServiceSettingsMap(url, model));
+
+        if (embeddingType != null) {
+            map.put(JinaAIEmbeddingsServiceSettings.EMBEDDING_TYPE, embeddingType.toString());
+        }
+
         return map;
     }
 }