浏览代码

Adding support for binary embedding type to Cohere service embedding type (#120751)

* Adding support for binary embedding type to Cohere service embedding type

* Returning response in separate text_embedding_bits field

* Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java

Co-authored-by: David Kyle <david.kyle@elastic.co>

* Update docs/changelog/120751.yaml

* Reverting docs change

---------

Co-authored-by: David Kyle <david.kyle@elastic.co>
Ying Mao 8 月之前
父节点
当前提交
89d71e1f6c
共有 18 个文件被更改,包括 690 次插入121 次删除
  1. 5 0
      docs/changelog/120751.yaml
  2. 1 1
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 95 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java
  4. 109 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java
  5. 1 80
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java
  6. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java
  7. 14 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java
  8. 17 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java
  9. 5 4
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
  10. 32 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java
  11. 47 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java
  12. 172 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java
  13. 2 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java
  14. 135 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java
  15. 10 16
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java
  16. 2 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java
  17. 25 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java
  18. 14 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java

+ 5 - 0
docs/changelog/120751.yaml

@@ -0,0 +1,5 @@
+pr: 120751
+summary: Adding support for binary embedding type to Cohere service embedding type
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 1
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -173,7 +173,7 @@ public class TransportVersions {
     public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00);
     public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00);
     public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00);
-
+    public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
     /*
      * STOP! READ THIS FIRST! No, really,
      *        ____ _____ ___  ____  _        ____  _____    _    ____    _____ _   _ ___ ____    _____ ___ ____  ____ _____ _

+ 95 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ *
+ * this file was contributed to by a generative AI
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
+    public static final String EMBEDDING = "embedding";
+
+    public InferenceByteEmbedding(StreamInput in) throws IOException {
+        this(in.readByteArray());
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeByteArray(values);
+    }
+
+    public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
+        byte[] embeddingValues = new byte[embeddingValuesList.size()];
+        for (int i = 0; i < embeddingValuesList.size(); i++) {
+            embeddingValues[i] = embeddingValuesList.get(i);
+        }
+        return new InferenceByteEmbedding(embeddingValues);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        builder.startArray(EMBEDDING);
+        for (byte value : values) {
+            builder.value(value);
+        }
+        builder.endArray();
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    float[] toFloatArray() {
+        float[] floatArray = new float[values.length];
+        for (int i = 0; i < values.length; i++) {
+            floatArray[i] = ((Byte) values[i]).floatValue();
+        }
+        return floatArray;
+    }
+
+    double[] toDoubleArray() {
+        double[] doubleArray = new double[values.length];
+        for (int i = 0; i < values.length; i++) {
+            doubleArray[i] = ((Byte) values[i]).doubleValue();
+        }
+        return doubleArray;
+    }
+
+    @Override
+    public int getSize() {
+        return values().length;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
+        return Arrays.equals(values, embedding.values);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(values);
+    }
+}

+ 109 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java

@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ *
+ * this file was contributed to by a generative AI
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Writes a text embedding result in the follow json format
+ * {
+ *     "text_embedding_bytes": [
+ *         {
+ *             "embedding": [
+ *                 23
+ *             ]
+ *         },
+ *         {
+ *             "embedding": [
+ *                 -23
+ *             ]
+ *         }
+ *     ]
+ * }
+ */
+public record InferenceTextEmbeddingBitResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
+    public static final String NAME = "text_embedding_service_bit_results";
+    public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
+
+    public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {
+        this(in.readCollectionAsList(InferenceByteEmbedding::new));
+    }
+
+    @Override
+    public int getFirstEmbeddingSize() {
+        return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
+    }
+
+    @Override
+    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+        return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddings.iterator());
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeCollection(embeddings);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToCoordinationFormat() {
+        return embeddings.stream()
+            .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
+            .toList();
+    }
+
+    @Override
+    @SuppressWarnings("deprecation")
+    public List<? extends InferenceResults> transformToLegacyFormat() {
+        var legacyEmbedding = new LegacyTextEmbeddingResults(
+            embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
+        );
+
+        return List.of(legacyEmbedding);
+    }
+
+    public Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(TEXT_EMBEDDING_BITS, embeddings);
+
+        return map;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        InferenceTextEmbeddingBitResults that = (InferenceTextEmbeddingBitResults) o;
+        return Objects.equals(embeddings, that.embeddings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(embeddings);
+    }
+}

+ 1 - 80
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java

@@ -9,21 +9,16 @@
 
 package org.elasticsearch.xpack.core.inference.results;
 
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.ToXContent;
-import org.elasticsearch.xcontent.ToXContentObject;
-import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -33,7 +28,7 @@ import java.util.Objects;
 /**
  * Writes a text embedding result in the follow json format
  * {
- *     "text_embedding": [
+ *     "text_embedding_bytes": [
  *         {
  *             "embedding": [
  *                 23
@@ -111,78 +106,4 @@ public record InferenceTextEmbeddingByteResults(List<InferenceByteEmbedding> emb
     public int hashCode() {
         return Objects.hash(embeddings);
     }
-
-    public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
-        public static final String EMBEDDING = "embedding";
-
-        public InferenceByteEmbedding(StreamInput in) throws IOException {
-            this(in.readByteArray());
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeByteArray(values);
-        }
-
-        public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
-            byte[] embeddingValues = new byte[embeddingValuesList.size()];
-            for (int i = 0; i < embeddingValuesList.size(); i++) {
-                embeddingValues[i] = embeddingValuesList.get(i);
-            }
-            return new InferenceByteEmbedding(embeddingValues);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-
-            builder.startArray(EMBEDDING);
-            for (byte value : values) {
-                builder.value(value);
-            }
-            builder.endArray();
-
-            builder.endObject();
-            return builder;
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
-        }
-
-        private float[] toFloatArray() {
-            float[] floatArray = new float[values.length];
-            for (int i = 0; i < values.length; i++) {
-                floatArray[i] = ((Byte) values[i]).floatValue();
-            }
-            return floatArray;
-        }
-
-        private double[] toDoubleArray() {
-            double[] doubleArray = new double[values.length];
-            for (int i = 0; i < values.length; i++) {
-                doubleArray[i] = ((Byte) values[i]).floatValue();
-            }
-            return doubleArray;
-        }
-
-        @Override
-        public int getSize() {
-            return values().length;
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
-            return Arrays.equals(values, embedding.values);
-        }
-
-        @Override
-        public int hashCode() {
-            return Arrays.hashCode(values);
-        }
-    }
 }

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingB
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -69,7 +70,7 @@ public class EmbeddingRequestChunker {
 
     private List<ChunkOffsetsAndInput> chunkedOffsets;
     private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
-    private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
+    private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
     private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
     private AtomicArray<Exception> errors;
     private ActionListener<List<ChunkedInference>> finalListener;
@@ -389,9 +390,9 @@ public class EmbeddingRequestChunker {
 
     private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
         ChunkOffsetsAndInput chunks,
-        AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults
+        AtomicArray<List<InferenceByteEmbedding>> debatchedResults
     ) {
-        var all = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
+        var all = new ArrayList<InferenceByteEmbedding>();
         for (int i = 0; i < debatchedResults.length(); i++) {
             var subBatch = debatchedResults.get(i);
             all.addAll(subBatch);

+ 14 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java

@@ -17,6 +17,8 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
@@ -43,7 +45,9 @@ public class CohereEmbeddingsResponseEntity {
         toLowerCase(CohereEmbeddingType.FLOAT),
         CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray,
         toLowerCase(CohereEmbeddingType.INT8),
-        CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray
+        CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray,
+        toLowerCase(CohereEmbeddingType.BINARY),
+        CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray
     );
     private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes();
 
@@ -184,17 +188,24 @@ public class CohereEmbeddingsResponseEntity {
         );
     }
 
+    private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException {
+        // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser
+        var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
+
+        return new InferenceTextEmbeddingBitResults(embeddingList);
+    }
+
     private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException {
         var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry);
 
         return new InferenceTextEmbeddingByteResults(embeddingList);
     }
 
-    private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
+    private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException {
         ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
         List<Byte> embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry);
 
-        return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList);
+        return InferenceByteEmbedding.of(embeddingValuesList);
     }
 
     private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException {

+ 17 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java

@@ -36,18 +36,29 @@ public enum CohereEmbeddingType {
     /**
      * This is a synonym for INT8
      */
-    BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8);
+    BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8),
+    /**
+     * Use this when you want to get back binary embeddings. Valid only for v3 models.
+     */
+    BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT),
+    /**
+     * This is a synonym for BIT
+     */
+    BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT);
 
     private static final class RequestConstants {
         private static final String FLOAT = "float";
         private static final String INT8 = "int8";
+        private static final String BIT = "binary";
     }
 
     private static final Map<DenseVectorFieldMapper.ElementType, CohereEmbeddingType> ELEMENT_TYPE_TO_COHERE_EMBEDDING = Map.of(
         DenseVectorFieldMapper.ElementType.FLOAT,
         FLOAT,
         DenseVectorFieldMapper.ElementType.BYTE,
-        BYTE
+        BYTE,
+        DenseVectorFieldMapper.ElementType.BIT,
+        BIT
     );
     static final EnumSet<DenseVectorFieldMapper.ElementType> SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf(
         ELEMENT_TYPE_TO_COHERE_EMBEDDING.keySet()
@@ -116,6 +127,10 @@ public enum CohereEmbeddingType {
             return INT8;
         }
 
+        if (version.before(TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED) && embeddingType == BIT) {
+            return INT8;
+        }
+
         return embeddingType;
     }
 }

+ 5 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingB
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -368,16 +369,16 @@ public class EmbeddingRequestChunkerTests extends ESTestCase {
 
         // 4 inputs in 2 batches
         {
-            var embeddings = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
+            var embeddings = new ArrayList<InferenceByteEmbedding>();
             for (int i = 0; i < batchSize; i++) {
-                embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() }));
+                embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
             }
             batches.get(0).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings));
         }
         {
-            var embeddings = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
+            var embeddings = new ArrayList<InferenceByteEmbedding>();
             for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
-                embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() }));
+                embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() }));
             }
             batches.get(1).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings));
         }

+ 32 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java

@@ -72,6 +72,38 @@ public class CohereEmbeddingsRequestEntityTests extends ESTestCase {
             {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}"""));
     }
 
+    public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException {
+        var entity = new CohereEmbeddingsRequestEntity(
+            List.of("abc"),
+            new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
+            "model",
+            CohereEmbeddingType.BINARY
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
+    }
+
+    public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException {
+        var entity = new CohereEmbeddingsRequestEntity(
+            List.of("abc"),
+            new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
+            "model",
+            CohereEmbeddingType.BIT
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, is("""
+            {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
+    }
+
     public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException {
         var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null);
 

+ 47 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java

@@ -145,6 +145,53 @@ public class CohereEmbeddingsRequestTests extends ESTestCase {
         );
     }
 
+    public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException {
+        var request = createRequest(
+            List.of("abc"),
+            CohereEmbeddingsModelTests.createModel(
+                "url",
+                "secret",
+                new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END),
+                null,
+                null,
+                "model",
+                CohereEmbeddingType.BIT
+            )
+        );
+
+        var httpRequest = request.createHttpRequest();
+        MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        MatcherAssert.assertThat(
+            httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
+            is(CohereUtils.ELASTIC_REQUEST_SOURCE)
+        );
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        MatcherAssert.assertThat(
+            requestMap,
+            is(
+                Map.of(
+                    "texts",
+                    List.of("abc"),
+                    "model",
+                    "model",
+                    "input_type",
+                    "search_query",
+                    "embedding_types",
+                    List.of("binary"),
+                    "truncate",
+                    "end"
+                )
+            )
+        );
+    }
+
     public void testCreateRequest_TruncateNone() throws IOException {
         var request = createRequest(
             List.of("abc"),

+ 172 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java

@@ -10,6 +10,8 @@ package org.elasticsearch.xpack.inference.external.response.cohere;
 import org.apache.http.HttpResponse;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
@@ -182,10 +184,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
-        MatcherAssert.assertThat(
-            parsedResults.embeddings(),
-            is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
-        );
+        MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
     }
 
     public void testFromResponse_ParsesBytes() throws IOException {
@@ -220,9 +219,47 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
+        MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
+    }
+
+    public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOException {
+        String responseJson = """
+            {
+                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+                "texts": [
+                    "hello"
+                ],
+                "embeddings": {
+                    "binary": [
+                        [
+                            -55,
+                            74,
+                            101,
+                            67,
+                            83
+                        ]
+                    ]
+                },
+                "meta": {
+                    "api_version": {
+                        "version": "2"
+                    },
+                    "billed_units": {
+                        "input_tokens": 1
+                    }
+                },
+                "response_type": "embeddings_by_type"
+            }
+            """;
+
+        InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
+            mock(Request.class),
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
         MatcherAssert.assertThat(
             parsedResults.embeddings(),
-            is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
+            is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
         );
     }
 
@@ -318,6 +355,59 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
         );
     }
 
+    public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary() throws IOException {
+        String responseJson = """
+            {
+                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+                "texts": [
+                    "hello",
+                    "goodbye"
+                ],
+                "embeddings": {
+                    "binary": [
+                        [
+                            -55,
+                            74,
+                            101,
+                            67
+                        ],
+                        [
+                            34,
+                            -64,
+                            97,
+                            65,
+                            -42
+                        ]
+                    ]
+                },
+                "meta": {
+                    "api_version": {
+                        "version": "2"
+                    },
+                    "billed_units": {
+                        "input_tokens": 1
+                    }
+                },
+                "response_type": "embeddings_by_type"
+            }
+            """;
+
+        InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
+            mock(Request.class),
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        MatcherAssert.assertThat(
+            parsedResults.embeddings(),
+            is(
+                List.of(
+                    new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
+                    new InferenceByteEmbedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
+                )
+            )
+        );
+    }
+
     public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() {
         String responseJson = """
             {
@@ -433,6 +523,82 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
         MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
     }
 
+    public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Negative() {
+        String responseJson = """
+            {
+                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+                "texts": [
+                    "hello"
+                ],
+                "embeddings": {
+                    "binary": [
+                        [
+                            -129,
+                            127
+                        ]
+                    ]
+                },
+                "meta": {
+                    "api_version": {
+                        "version": "2"
+                    },
+                    "billed_units": {
+                        "input_tokens": 1
+                    }
+                },
+                "response_type": "embeddings_by_type"
+            }
+            """;
+
+        var thrownException = expectThrows(
+            IllegalArgumentException.class,
+            () -> CohereEmbeddingsResponseEntity.fromResponse(
+                mock(Request.class),
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        MatcherAssert.assertThat(thrownException.getMessage(), is("Value [-129] is out of range for a byte"));
+    }
+
+    public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Positive() {
+        String responseJson = """
+            {
+                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+                "texts": [
+                    "hello"
+                ],
+                "embeddings": {
+                    "binary": [
+                        [
+                            -128,
+                            128
+                        ]
+                    ]
+                },
+                "meta": {
+                    "api_version": {
+                        "version": "2"
+                    },
+                    "billed_units": {
+                        "input_tokens": 1
+                    }
+                },
+                "response_type": "embeddings_by_type"
+            }
+            """;
+
+        var thrownException = expectThrows(
+            IllegalArgumentException.class,
+            () -> CohereEmbeddingsResponseEntity.fromResponse(
+                mock(Request.class),
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
+    }
+
     public void testFromResponse_FailsToFindAValidEmbeddingType() {
         String responseJson = """
             {
@@ -470,7 +636,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
 
         MatcherAssert.assertThat(
             thrownException.getMessage(),
-            is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]")
+            is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [binary, float, int8]")
         );
     }
 }

+ 2 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.test.rest.FakeRestRequest;
 import org.elasticsearch.test.rest.RestActionTestCase;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.junit.Before;
 
@@ -142,9 +143,7 @@ public class BaseInferenceActionTests extends RestActionTestCase {
 
     static InferenceAction.Response createResponse() {
         return new InferenceAction.Response(
-            new InferenceTextEmbeddingByteResults(
-                List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 }))
-            )
+            new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1 })))
         );
     }
 }

+ 135 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java

@@ -0,0 +1,135 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.results;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class InferenceTextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase<InferenceTextEmbeddingBitResults> {
+    public static InferenceTextEmbeddingBitResults createRandomResults() {
+        int embeddings = randomIntBetween(1, 10);
+        List<InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
+
+        for (int i = 0; i < embeddings; i++) {
+            embeddingResults.add(createRandomEmbedding());
+        }
+
+        return new InferenceTextEmbeddingBitResults(embeddingResults);
+    }
+
+    private static InferenceByteEmbedding createRandomEmbedding() {
+        int columns = randomIntBetween(1, 10);
+        byte[] bytes = new byte[columns];
+
+        for (int i = 0; i < columns; i++) {
+            bytes[i] = randomByte();
+        }
+
+        return new InferenceByteEmbedding(bytes);
+    }
+
+    public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
+        var entity = new InferenceTextEmbeddingBitResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 })));
+
+        String xContentResult = Strings.toString(entity, true, true);
+        assertThat(xContentResult, is("""
+            {
+              "text_embedding_bits" : [
+                {
+                  "embedding" : [
+                    23
+                  ]
+                }
+              ]
+            }"""));
+    }
+
+    public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
+        var entity = new InferenceTextEmbeddingBitResults(
+            List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 }))
+        );
+
+        String xContentResult = Strings.toString(entity, true, true);
+        assertThat(xContentResult, is("""
+            {
+              "text_embedding_bits" : [
+                {
+                  "embedding" : [
+                    23
+                  ]
+                },
+                {
+                  "embedding" : [
+                    24
+                  ]
+                }
+              ]
+            }"""));
+    }
+
+    public void testTransformToCoordinationFormat() {
+        var results = new InferenceTextEmbeddingBitResults(
+            List.of(
+                new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
+                new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
+            )
+        ).transformToCoordinationFormat();
+
+        assertThat(
+            results,
+            is(
+                List.of(
+                    new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false),
+                    new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false)
+                )
+            )
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<InferenceTextEmbeddingBitResults> instanceReader() {
+        return InferenceTextEmbeddingBitResults::new;
+    }
+
+    @Override
+    protected InferenceTextEmbeddingBitResults createTestInstance() {
+        return createRandomResults();
+    }
+
+    @Override
+    protected InferenceTextEmbeddingBitResults mutateInstance(InferenceTextEmbeddingBitResults instance) throws IOException {
+        // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list
+        if (randomBoolean()) {
+            // -1 to remove at least one item from the list
+            int end = randomInt(instance.embeddings().size() - 1);
+            return new InferenceTextEmbeddingBitResults(instance.embeddings().subList(0, end));
+        } else {
+            List<InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
+            embeddings.add(createRandomEmbedding());
+            return new InferenceTextEmbeddingBitResults(embeddings);
+        }
+    }
+
+    public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) {
+        return Map.of(
+            InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS,
+            embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList()
+        );
+    }
+}

+ 10 - 16
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.results;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 
@@ -23,7 +24,7 @@ import static org.hamcrest.Matchers.is;
 public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase<InferenceTextEmbeddingByteResults> {
     public static InferenceTextEmbeddingByteResults createRandomResults() {
         int embeddings = randomIntBetween(1, 10);
-        List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
+        List<InferenceByteEmbedding> embeddingResults = new ArrayList<>(embeddings);
 
         for (int i = 0; i < embeddings; i++) {
             embeddingResults.add(createRandomEmbedding());
@@ -32,7 +33,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
         return new InferenceTextEmbeddingByteResults(embeddingResults);
     }
 
-    private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding createRandomEmbedding() {
+    private static InferenceByteEmbedding createRandomEmbedding() {
         int columns = randomIntBetween(1, 10);
         byte[] bytes = new byte[columns];
 
@@ -40,13 +41,11 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
             bytes[i] = randomByte();
         }
 
-        return new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(bytes);
+        return new InferenceByteEmbedding(bytes);
     }
 
     public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException {
-        var entity = new InferenceTextEmbeddingByteResults(
-            List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }))
-        );
+        var entity = new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 })));
 
         String xContentResult = Strings.toString(entity, true, true);
         assertThat(xContentResult, is("""
@@ -63,10 +62,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
 
     public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException {
         var entity = new InferenceTextEmbeddingByteResults(
-            List.of(
-                new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }),
-                new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 24 })
-            )
+            List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 }))
         );
 
         String xContentResult = Strings.toString(entity, true, true);
@@ -90,8 +86,8 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
     public void testTransformToCoordinationFormat() {
         var results = new InferenceTextEmbeddingByteResults(
             List.of(
-                new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
-                new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
+                new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }),
+                new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 })
             )
         ).transformToCoordinationFormat();
 
@@ -124,7 +120,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
             int end = randomInt(instance.embeddings().size() - 1);
             return new InferenceTextEmbeddingByteResults(instance.embeddings().subList(0, end));
         } else {
-            List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
+            List<InferenceByteEmbedding> embeddings = new ArrayList<>(instance.embeddings());
             embeddings.add(createRandomEmbedding());
             return new InferenceTextEmbeddingByteResults(embeddings);
         }
@@ -133,9 +129,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializ
     public static Map<String, Object> buildExpectationByte(List<List<Byte>> embeddings) {
         return Map.of(
             InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
-            embeddings.stream()
-                .map(embedding -> Map.of(InferenceTextEmbeddingByteResults.InferenceByteEmbedding.EMBEDDING, embedding))
-                .toList()
+            embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList()
         );
     }
 }

+ 2 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.results;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -141,7 +142,7 @@ public class TextEmbeddingResultsTests extends AbstractWireSerializingTestCase<I
     public static Map<String, Object> buildExpectationByte(List<byte[]> embeddings) {
         return Map.of(
             InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES,
-            embeddings.stream().map(InferenceTextEmbeddingByteResults.InferenceByteEmbedding::new).toList()
+            embeddings.stream().map(InferenceByteEmbedding::new).toList()
         );
     }
 

+ 25 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java

@@ -50,6 +50,27 @@ public class CohereEmbeddingTypeTests extends ESTestCase {
         );
     }
 
+    public void testTranslateToVersion_ReturnsInt8_WhenVersionIsBeforeBitEnumAddition_WhenSpecifyingBit() {
+        assertThat(
+            CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, new TransportVersion(8_840_0_00)),
+            is(CohereEmbeddingType.INT8)
+        );
+    }
+
+    public void testTranslateToVersion_ReturnsBit_WhenVersionOnBitEnumAddition_WhenSpecifyingBit() {
+        assertThat(
+            CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED),
+            is(CohereEmbeddingType.BIT)
+        );
+    }
+
+    public void testTranslateToVersion_ReturnsFloat_WhenVersionOnBitEnumAddition_WhenSpecifyingFloat() {
+        assertThat(
+            CohereEmbeddingType.translateToVersion(CohereEmbeddingType.FLOAT, TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED),
+            is(CohereEmbeddingType.FLOAT)
+        );
+    }
+
     public void testFromElementType_CovertsFloatToCohereEmbeddingTypeFloat() {
         assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.FLOAT), is(CohereEmbeddingType.FLOAT));
     }
@@ -57,4 +78,8 @@ public class CohereEmbeddingTypeTests extends ESTestCase {
     public void testFromElementType_CovertsByteToCohereEmbeddingTypeByte() {
         assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BYTE), is(CohereEmbeddingType.BYTE));
     }
+
+    public void testFromElementType_ConvertsBitToCohereEmbeddingTypeBinary() {
+        assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BIT), is(CohereEmbeddingType.BIT));
+    }
 }

+ 14 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java

@@ -218,7 +218,7 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             is(
                 Strings.format(
                     "Validation Failed: 1: [service_settings] Invalid value [abc] received. "
-                        + "[embedding_type] must be one of [byte, float, int8];"
+                        + "[embedding_type] must be one of [binary, bit, byte, float, int8];"
                 )
             )
         );
@@ -238,7 +238,7 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
             is(
                 Strings.format(
                     "Validation Failed: 1: [service_settings] Invalid value [abc] received. "
-                        + "[embedding_type] must be one of [byte, float];"
+                        + "[embedding_type] must be one of [bit, byte, float];"
                 )
             )
         );
@@ -289,6 +289,16 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         );
     }
 
+    public void testFromMap_ConvertsBit_ToCohereEmbeddingTypeBit() {
+        assertThat(
+            CohereEmbeddingsServiceSettings.fromMap(
+                new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.BIT.toString())),
+                ConfigurationParseContext.REQUEST
+            ),
+            is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BIT))
+        );
+    }
+
     public void testFromMap_PreservesEmbeddingTypeFloat() {
         assertThat(
             CohereEmbeddingsServiceSettings.fromMap(
@@ -314,6 +324,8 @@ public class CohereEmbeddingsServiceSettingsTests extends AbstractWireSerializin
         assertEquals(CohereEmbeddingType.BYTE, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("byte", validation));
         assertEquals(CohereEmbeddingType.INT8, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("int8", validation));
         assertEquals(CohereEmbeddingType.FLOAT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("float", validation));
+        assertEquals(CohereEmbeddingType.BINARY, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("binary", validation));
+        assertEquals(CohereEmbeddingType.BIT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("bit", validation));
         assertTrue(validation.validationErrors().isEmpty());
     }