|
@@ -0,0 +1,271 @@
|
|
|
+/*
|
|
|
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
|
+ * or more contributor license agreements. Licensed under the Elastic License
|
|
|
+ * 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
+ * 2.0.
|
|
|
+ */
|
|
|
+
|
|
|
+package org.elasticsearch.xpack.inference.qa.mixed;
|
|
|
+
|
|
|
+import org.elasticsearch.Version;
|
|
|
+import org.elasticsearch.common.Strings;
|
|
|
+import org.elasticsearch.inference.TaskType;
|
|
|
+import org.elasticsearch.test.http.MockResponse;
|
|
|
+import org.elasticsearch.test.http.MockWebServer;
|
|
|
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
|
|
|
+import org.hamcrest.Matchers;
|
|
|
+import org.junit.AfterClass;
|
|
|
+import org.junit.BeforeClass;
|
|
|
+
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
|
|
|
+import static org.hamcrest.Matchers.empty;
|
|
|
+import static org.hamcrest.Matchers.hasEntry;
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
|
+import static org.hamcrest.Matchers.not;
|
|
|
+import static org.hamcrest.Matchers.oneOf;
|
|
|
+
|
|
|
+public class CohereServiceMixedIT extends BaseMixedTestCase {
|
|
|
+
|
|
|
+ private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
|
|
|
+ private static final String COHERE_RERANK_ADDED = "8.14.0";
|
|
|
+ private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
|
|
|
+ private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";
|
|
|
+
|
|
|
+ private static MockWebServer cohereEmbeddingsServer;
|
|
|
+ private static MockWebServer cohereRerankServer;
|
|
|
+
|
|
|
+ @BeforeClass
|
|
|
+ public static void startWebServer() throws IOException {
|
|
|
+ cohereEmbeddingsServer = new MockWebServer();
|
|
|
+ cohereEmbeddingsServer.start();
|
|
|
+
|
|
|
+ cohereRerankServer = new MockWebServer();
|
|
|
+ cohereRerankServer.start();
|
|
|
+ }
|
|
|
+
|
|
|
+ @AfterClass
|
|
|
+ public static void shutdown() {
|
|
|
+ cohereEmbeddingsServer.close();
|
|
|
+ cohereRerankServer.close();
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public void testCohereEmbeddings() throws IOException {
|
|
|
+ var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED));
|
|
|
+ assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);
|
|
|
+ assumeTrue(
|
|
|
+ "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION,
|
|
|
+ bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
|
|
|
+ );
|
|
|
+
|
|
|
+ final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8";
|
|
|
+ final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float";
|
|
|
+
|
|
|
+ // queue a response as PUT will call the service
|
|
|
+ cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
|
|
|
+ put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+
|
|
|
+ // float model
|
|
|
+ cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
|
|
|
+ put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+
|
|
|
+ var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints");
|
|
|
+ assertEquals("cohere", configs.get(0).get("service"));
|
|
|
+ var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
+ var embeddingType = serviceSettings.get("embedding_type");
|
|
|
+ // An upgraded node will report the embedding type as byte, an old node int8
|
|
|
+ assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));
|
|
|
+
|
|
|
+ configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdFloat).get("endpoints");
|
|
|
+ serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("embedding_type", "float"));
|
|
|
+
|
|
|
+ assertEmbeddingInference(inferenceIdInt8, CohereEmbeddingType.BYTE);
|
|
|
+ assertEmbeddingInference(inferenceIdFloat, CohereEmbeddingType.FLOAT);
|
|
|
+
|
|
|
+ delete(inferenceIdFloat);
|
|
|
+ delete(inferenceIdInt8);
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
|
|
|
+ switch (type) {
|
|
|
+ case INT8:
|
|
|
+ case BYTE:
|
|
|
+ cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
|
|
|
+ break;
|
|
|
+ case FLOAT:
|
|
|
+ cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
|
|
|
+ }
|
|
|
+
|
|
|
+ var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
|
|
|
+ assertThat(inferenceMap.entrySet(), not(empty()));
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public void testRerank() throws IOException {
|
|
|
+ var rerankSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED));
|
|
|
+ assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);
|
|
|
+ assumeTrue(
|
|
|
+ "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION,
|
|
|
+ bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
|
|
|
+ );
|
|
|
+
|
|
|
+ final String inferenceId = "mixed-cluster-rerank";
|
|
|
+
|
|
|
+ put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
|
|
|
+ assertRerank(inferenceId);
|
|
|
+
|
|
|
+ var configs = (List<Map<String, Object>>) get(TaskType.RERANK, inferenceId).get("endpoints");
|
|
|
+ assertThat(configs, hasSize(1));
|
|
|
+ assertEquals("cohere", configs.get(0).get("service"));
|
|
|
+ var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0"));
|
|
|
+ var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
|
|
|
+ assertThat(taskSettings, hasEntry("top_n", 3));
|
|
|
+
|
|
|
+ assertRerank(inferenceId);
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ private void assertRerank(String inferenceId) throws IOException {
|
|
|
+ cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
|
|
|
+ var inferenceMap = rerank(
|
|
|
+ inferenceId,
|
|
|
+ List.of("luke", "like", "leia", "chewy", "r2d2", "star", "wars"),
|
|
|
+ "star wars main character"
|
|
|
+ );
|
|
|
+ assertThat(inferenceMap.entrySet(), not(empty()));
|
|
|
+ }
|
|
|
+
|
|
|
+ private String embeddingConfigByte(String url) {
|
|
|
+ return embeddingConfigTemplate(url, "byte");
|
|
|
+ }
|
|
|
+
|
|
|
+ private String embeddingConfigInt8(String url) {
|
|
|
+ return embeddingConfigTemplate(url, "int8");
|
|
|
+ }
|
|
|
+
|
|
|
+ private String embeddingConfigFloat(String url) {
|
|
|
+ return embeddingConfigTemplate(url, "float");
|
|
|
+ }
|
|
|
+
|
|
|
+ private String embeddingConfigTemplate(String url, String embeddingType) {
|
|
|
+ return Strings.format("""
|
|
|
+ {
|
|
|
+ "service": "cohere",
|
|
|
+ "service_settings": {
|
|
|
+ "url": "%s",
|
|
|
+ "api_key": "XXXX",
|
|
|
+ "model_id": "embed-english-light-v3.0",
|
|
|
+ "embedding_type": "%s"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """, url, embeddingType);
|
|
|
+ }
|
|
|
+
|
|
|
+ private String embeddingResponseByte() {
|
|
|
+ return """
|
|
|
+ {
|
|
|
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": [
|
|
|
+ [
|
|
|
+ 12,
|
|
|
+ 56
|
|
|
+ ]
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "1"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_bytes"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+ }
|
|
|
+
|
|
|
+ private String embeddingResponseFloat() {
|
|
|
+ return """
|
|
|
+ {
|
|
|
+ "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": [
|
|
|
+ [
|
|
|
+ -0.0018434525,
|
|
|
+ 0.01777649
|
|
|
+ ]
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "1"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_floats"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+ }
|
|
|
+
|
|
|
+ private String rerankConfig(String url) {
|
|
|
+ return Strings.format("""
|
|
|
+ {
|
|
|
+ "service": "cohere",
|
|
|
+ "service_settings": {
|
|
|
+ "api_key": "XXXX",
|
|
|
+ "model_id": "rerank-english-v3.0",
|
|
|
+ "url": "%s"
|
|
|
+ },
|
|
|
+ "task_settings": {
|
|
|
+ "return_documents": false,
|
|
|
+ "top_n": 3
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """, url);
|
|
|
+ }
|
|
|
+
|
|
|
+ private String rerankResponse() {
|
|
|
+ return """
|
|
|
+ {
|
|
|
+ "index": "d0760819-5a73-4d58-b163-3956d3648b62",
|
|
|
+ "results": [
|
|
|
+ {
|
|
|
+ "index": 2,
|
|
|
+ "relevance_score": 0.98005307
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "index": 3,
|
|
|
+ "relevance_score": 0.27904198
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "index": 0,
|
|
|
+ "relevance_score": 0.10194652
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "1"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "search_units": 1
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """;
|
|
|
+ }
|
|
|
+
|
|
|
+}
|