|
@@ -10,6 +10,8 @@ package org.elasticsearch.xpack.inference.external.response.cohere;
|
|
|
import org.apache.http.HttpResponse;
|
|
|
import org.elasticsearch.inference.InferenceServiceResults;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
|
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
|
@@ -182,10 +184,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
|
);
|
|
|
|
|
|
- MatcherAssert.assertThat(
|
|
|
- parsedResults.embeddings(),
|
|
|
- is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
|
|
|
- );
|
|
|
+ MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
|
|
|
}
|
|
|
|
|
|
public void testFromResponse_ParsesBytes() throws IOException {
|
|
@@ -220,9 +219,47 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
|
);
|
|
|
|
|
|
+ MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOException {
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": {
|
|
|
+ "binary": [
|
|
|
+ [
|
|
|
+ -55,
|
|
|
+ 74,
|
|
|
+ 101,
|
|
|
+ 67,
|
|
|
+ 83
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "2"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_by_type"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+
|
|
|
+ InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
|
|
|
+ mock(Request.class),
|
|
|
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
|
+ );
|
|
|
+
|
|
|
MatcherAssert.assertThat(
|
|
|
parsedResults.embeddings(),
|
|
|
- is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))
|
|
|
+ is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 })))
|
|
|
);
|
|
|
}
|
|
|
|
|
@@ -318,6 +355,59 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary() throws IOException {
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
|
|
+ "texts": [
|
|
|
+ "hello",
|
|
|
+ "goodbye"
|
|
|
+ ],
|
|
|
+ "embeddings": {
|
|
|
+ "binary": [
|
|
|
+ [
|
|
|
+ -55,
|
|
|
+ 74,
|
|
|
+ 101,
|
|
|
+ 67
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 34,
|
|
|
+ -64,
|
|
|
+ 97,
|
|
|
+ 65,
|
|
|
+ -42
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "2"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_by_type"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+
|
|
|
+ InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse(
|
|
|
+ mock(Request.class),
|
|
|
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
|
+ );
|
|
|
+
|
|
|
+ MatcherAssert.assertThat(
|
|
|
+ parsedResults.embeddings(),
|
|
|
+ is(
|
|
|
+ List.of(
|
|
|
+ new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }),
|
|
|
+ new InferenceByteEmbedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 })
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() {
|
|
|
String responseJson = """
|
|
|
{
|
|
@@ -433,6 +523,82 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
|
|
|
}
|
|
|
|
|
|
+ public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Negative() {
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": {
|
|
|
+ "binary": [
|
|
|
+ [
|
|
|
+ -129,
|
|
|
+ 127
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "2"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_by_type"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+
|
|
|
+ var thrownException = expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ () -> CohereEmbeddingsResponseEntity.fromResponse(
|
|
|
+ mock(Request.class),
|
|
|
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
|
+ )
|
|
|
+ );
|
|
|
+
|
|
|
+ MatcherAssert.assertThat(thrownException.getMessage(), is("Value [-129] is out of range for a byte"));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Positive() {
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": {
|
|
|
+ "binary": [
|
|
|
+ [
|
|
|
+ -128,
|
|
|
+ 128
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "2"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_by_type"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+
|
|
|
+ var thrownException = expectThrows(
|
|
|
+ IllegalArgumentException.class,
|
|
|
+ () -> CohereEmbeddingsResponseEntity.fromResponse(
|
|
|
+ mock(Request.class),
|
|
|
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
|
+ )
|
|
|
+ );
|
|
|
+
|
|
|
+ MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte"));
|
|
|
+ }
|
|
|
+
|
|
|
public void testFromResponse_FailsToFindAValidEmbeddingType() {
|
|
|
String responseJson = """
|
|
|
{
|
|
@@ -470,7 +636,7 @@ public class CohereEmbeddingsResponseEntityTests extends ESTestCase {
|
|
|
|
|
|
MatcherAssert.assertThat(
|
|
|
thrownException.getMessage(),
|
|
|
- is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]")
|
|
|
+ is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [binary, float, int8]")
|
|
|
);
|
|
|
}
|
|
|
}
|