|
@@ -15,6 +15,8 @@ import org.elasticsearch.test.http.MockResponse;
|
|
|
import org.elasticsearch.test.http.MockWebServer;
|
|
|
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
|
|
|
import org.hamcrest.Matchers;
|
|
|
+import org.junit.AfterClass;
|
|
|
+import org.junit.BeforeClass;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.List;
|
|
@@ -39,7 +41,7 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
super(upgradedNodes);
|
|
|
}
|
|
|
|
|
|
- // @BeforeClass
|
|
|
+ @BeforeClass
|
|
|
public static void startWebServer() throws IOException {
|
|
|
cohereEmbeddingsServer = new MockWebServer();
|
|
|
cohereEmbeddingsServer.start();
|
|
@@ -48,58 +50,74 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
cohereRerankServer.start();
|
|
|
}
|
|
|
|
|
|
- // @AfterClass // for the awaitsfix
|
|
|
+ @AfterClass
|
|
|
public static void shutdown() {
|
|
|
cohereEmbeddingsServer.close();
|
|
|
cohereRerankServer.close();
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887")
|
|
|
public void testCohereEmbeddings() throws IOException {
|
|
|
var embeddingsSupported = getOldClusterTestVersion().onOrAfter(COHERE_EMBEDDINGS_ADDED);
|
|
|
+ // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS
|
|
|
+ String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models";
|
|
|
assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);
|
|
|
|
|
|
final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
|
|
|
final String oldClusterIdFloat = "old-cluster-embeddings-float";
|
|
|
|
|
|
+ var testTaskType = TaskType.TEXT_EMBEDDING;
|
|
|
+
|
|
|
if (isOldCluster()) {
|
|
|
// queue a response as PUT will call the service
|
|
|
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
|
|
|
- put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+ put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
|
|
|
// float model
|
|
|
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
|
|
|
- put(oldClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+ put(oldClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
|
|
|
|
|
|
- var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints");
|
|
|
- assertThat(configs, hasSize(1));
|
|
|
- assertEquals("cohere", configs.get(0).get("service"));
|
|
|
- var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
- assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
- var embeddingType = serviceSettings.get("embedding_type");
|
|
|
- // An upgraded node will report the embedding type as byte, the old node int8
|
|
|
- assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));
|
|
|
-
|
|
|
- assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
|
|
|
- assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
|
|
|
+ {
|
|
|
+ var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterIdInt8).get(oldClusterEndpointIdentifier);
|
|
|
+ assertThat(configs, hasSize(1));
|
|
|
+ assertEquals("cohere", configs.get(0).get("service"));
|
|
|
+ var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
+ var embeddingType = serviceSettings.get("embedding_type");
|
|
|
+ // An upgraded node will report the embedding type as byte, the old node int8
|
|
|
+ assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));
|
|
|
+ assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
|
|
|
+ }
|
|
|
+ {
|
|
|
+ var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterIdFloat).get(oldClusterEndpointIdentifier);
|
|
|
+ assertThat(configs, hasSize(1));
|
|
|
+ assertEquals("cohere", configs.get(0).get("service"));
|
|
|
+ var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
+ assertThat(serviceSettings, hasEntry("embedding_type", "float"));
|
|
|
+ assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
|
|
|
+ }
|
|
|
} else if (isMixedCluster()) {
|
|
|
- var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints");
|
|
|
- assertEquals("cohere", configs.get(0).get("service"));
|
|
|
- var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
- assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
- var embeddingType = serviceSettings.get("embedding_type");
|
|
|
- // An upgraded node will report the embedding type as byte, an old node int8
|
|
|
- assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));
|
|
|
-
|
|
|
- configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterIdFloat).get("endpoints");
|
|
|
- serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
- assertThat(serviceSettings, hasEntry("embedding_type", "float"));
|
|
|
-
|
|
|
- assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
|
|
|
- assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
|
|
|
+ {
|
|
|
+ var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterIdInt8);
|
|
|
+ assertEquals("cohere", configs.get(0).get("service"));
|
|
|
+ var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
+ var embeddingType = serviceSettings.get("embedding_type");
|
|
|
+ // An upgraded node will report the embedding type as byte, an old node int8
|
|
|
+ assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));
|
|
|
+ assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
|
|
|
+ }
|
|
|
+ {
|
|
|
+ var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterIdFloat);
|
|
|
+ assertEquals("cohere", configs.get(0).get("service"));
|
|
|
+ var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
+ assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
+ assertThat(serviceSettings, hasEntry("embedding_type", "float"));
|
|
|
+ assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
|
|
|
+ }
|
|
|
} else if (isUpgradedCluster()) {
|
|
|
// check old cluster model
|
|
|
- var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints");
|
|
|
+ var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterIdInt8).get("endpoints");
|
|
|
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
|
|
|
assertThat(serviceSettings, hasEntry("embedding_type", "byte"));
|
|
@@ -114,9 +132,9 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";
|
|
|
|
|
|
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
|
|
|
- put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+ put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
|
|
|
|
|
|
- configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdByte).get("endpoints");
|
|
|
+ configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
|
|
|
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
assertThat(serviceSettings, hasEntry("embedding_type", "byte"));
|
|
|
|
|
@@ -127,9 +145,9 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";
|
|
|
|
|
|
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
|
|
|
- put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+ put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
|
|
|
|
|
|
- configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdInt8).get("endpoints");
|
|
|
+ configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
|
|
|
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte
|
|
|
|
|
@@ -139,9 +157,9 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
{
|
|
|
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
|
|
|
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
|
|
|
- put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
|
|
|
+ put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
|
|
|
|
|
|
- configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdFloat).get("endpoints");
|
|
|
+ configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
|
|
|
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
assertThat(serviceSettings, hasEntry("embedding_type", "float"));
|
|
|
|
|
@@ -169,22 +187,25 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887")
|
|
|
public void testRerank() throws IOException {
|
|
|
var rerankSupported = getOldClusterTestVersion().onOrAfter(COHERE_RERANK_ADDED);
|
|
|
+ String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models";
|
|
|
assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);
|
|
|
|
|
|
final String oldClusterId = "old-cluster-rerank";
|
|
|
final String upgradedClusterId = "upgraded-cluster-rerank";
|
|
|
|
|
|
+ var testTaskType = TaskType.RERANK;
|
|
|
+
|
|
|
if (isOldCluster()) {
|
|
|
- put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
|
|
|
- var configs = (List<Map<String, Object>>) get(TaskType.RERANK, oldClusterId).get("endpoints");
|
|
|
+ put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType);
|
|
|
+ var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get(oldClusterEndpointIdentifier);
|
|
|
assertThat(configs, hasSize(1));
|
|
|
|
|
|
assertRerank(oldClusterId);
|
|
|
} else if (isMixedCluster()) {
|
|
|
- var configs = (List<Map<String, Object>>) get(TaskType.RERANK, oldClusterId).get("endpoints");
|
|
|
+ var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterId);
|
|
|
+
|
|
|
assertEquals("cohere", configs.get(0).get("service"));
|
|
|
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0"));
|
|
@@ -195,7 +216,7 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
|
|
|
} else if (isUpgradedCluster()) {
|
|
|
// check old cluster model
|
|
|
- var configs = (List<Map<String, Object>>) get(TaskType.RERANK, oldClusterId).get("endpoints");
|
|
|
+ var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
|
|
|
assertEquals("cohere", configs.get(0).get("service"));
|
|
|
var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
|
|
|
assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0"));
|
|
@@ -205,7 +226,7 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
|
|
|
assertRerank(oldClusterId);
|
|
|
|
|
|
// New endpoint
|
|
|
- put(upgradedClusterId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
|
|
|
+ put(upgradedClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType);
|
|
|
configs = (List<Map<String, Object>>) get(upgradedClusterId).get("endpoints");
|
|
|
assertThat(configs, hasSize(1));
|
|
|
|