Browse Source

[ML] Add mixed cluster tests for inference (#108392)

* mixed cluster tests are executable

* add tests from upgrade tests

* [ML] Add mixed cluster tests for existing services

* clean up

* review improvements

* spotless

* remove blocked AzureOpenAI mixed IT

* improvements from DK review

* temp for testing

* refactoring and documentation

* Revert manual testing configs of "temp for testing"

This reverts parts of commit fca46fd2b6253accc010a2e2a8bf05edfff5ea9b.

* revert TESTING.asciidoc formatting

* Update TESTING.asciidoc to avoid reformatting

* add minimum version for tests to match minimum version in services

* spotless
Max Hniebergall 1 year ago
parent
commit
c88a6fe481

+ 11 - 5
TESTING.asciidoc

@@ -551,13 +551,19 @@ When running `./gradlew check`, minimal bwc checks are also run against compatib
 
 ==== BWC Testing against a specific remote/branch
 
-Sometimes a backward compatibility change spans two versions. A common case is a new functionality
-that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x).
-To test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of
-pulling the release branch from GitHub. You do so using the `bwc.remote` and `bwc.refspec.BRANCH` system properties:
+Sometimes a backward compatibility change spans two versions.
+A common case is a new functionality that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x).
+Another use case, since the introduction of serverless, is to test BWC against main in addition to the other released branches.
+To do so, specify the `bwc.refspec` remote and branch to use for the BWC build as `origin/main`.
+To test against main, you will also need to create a new version in link:./server/src/main/java/org/elasticsearch/Version.java[Version.java], 
+increment `elasticsearch` in link:./build-tools-internal/version.properties[version.properties], and hard-code the `project.version` for ml-cpp 
+in link:./x-pack/plugin/ml/build.gradle[ml/build.gradle].
+
+In general, to test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of pulling the release branch from GitHub.
+You do so using the `bwc.refspec.{VERSION}` system property:
 
 -------------------------------------------------
-./gradlew check -Dbwc.remote=${remote} -Dbwc.refspec.5.x=index_req_bwc_5.x
+./gradlew check -Dtests.bwc.refspec.8.15=origin/main
 -------------------------------------------------
 
 The branch needs to be available on the remote that the BWC makes of the

+ 37 - 0
x-pack/plugin/inference/qa/mixed-cluster/build.gradle

@@ -0,0 +1,37 @@
+import org.elasticsearch.gradle.Version
+import org.elasticsearch.gradle.VersionProperties
+import org.elasticsearch.gradle.util.GradleUtils
+import org.elasticsearch.gradle.internal.info.BuildParams
+import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask
+
+apply plugin: 'elasticsearch.internal-java-rest-test'
+apply plugin: 'elasticsearch.internal-test-artifact-base'
+apply plugin: 'elasticsearch.bwc-test'
+
+dependencies {
+  testImplementation project(path: ':x-pack:plugin:inference:qa:inference-service-tests')
+  compileOnly project(':x-pack:plugin:core')
+  javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
+  javaRestTestImplementation project(path: xpackModule('inference'))
+  clusterPlugins project(
+    ':x-pack:plugin:inference:qa:test-service-plugin'
+  )
+}
+
+// inference is available in 8.11 or later
+def supportedVersion = bwcVersion -> {
+  return bwcVersion.onOrAfter(Version.fromString("8.11.0"));
+}
+
+BuildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName ->
+  def javaRestTest = tasks.register("v${bwcVersion}#javaRestTest", StandaloneRestIntegTestTask) {
+    usesBwcDistribution(bwcVersion)
+    systemProperty("tests.old_cluster_version", bwcVersion)
+    maxParallelForks = 1
+  }
+
+  tasks.register(bwcTaskName(bwcVersion)) {
+    dependsOn javaRestTest
+  }
+}
+

+ 129 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/BaseMixedTestCase.java

@@ -0,0 +1,129 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.qa.mixed;
+
+import org.apache.http.util.EntityUtils;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.hamcrest.Matchers;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+public abstract class BaseMixedTestCase extends MixedClusterSpecTestCase {
+    protected static String getUrl(MockWebServer webServer) {
+        return Strings.format("http://%s:%s", webServer.getHostName(), webServer.getPort());
+    }
+
+    @Override
+    protected Settings restClientSettings() {
+        String token = ESRestTestCase.basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
+        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
+    }
+
+    protected void delete(String inferenceId, TaskType taskType) throws IOException {
+        var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, inferenceId));
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+    }
+
+    protected void delete(String inferenceId) throws IOException {
+        var request = new Request("DELETE", Strings.format("_inference/%s", inferenceId));
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+    }
+
+    protected Map<String, Object> getAll() throws IOException {
+        var request = new Request("GET", "_inference/_all");
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+        return ESRestTestCase.entityAsMap(response);
+    }
+
+    protected Map<String, Object> get(String inferenceId) throws IOException {
+        var endpoint = Strings.format("_inference/%s", inferenceId);
+        var request = new Request("GET", endpoint);
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+        return ESRestTestCase.entityAsMap(response);
+    }
+
+    protected Map<String, Object> get(TaskType taskType, String inferenceId) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
+        var request = new Request("GET", endpoint);
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+        return ESRestTestCase.entityAsMap(response);
+    }
+
+    protected Map<String, Object> inference(String inferenceId, TaskType taskType, String input) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
+        var request = new Request("POST", endpoint);
+        request.setJsonEntity("{\"input\": [" + '"' + input + '"' + "]}");
+
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+        return ESRestTestCase.entityAsMap(response);
+    }
+
+    protected Map<String, Object> rerank(String inferenceId, List<String> inputs, String query) throws IOException {
+        var endpoint = Strings.format("_inference/rerank/%s", inferenceId);
+        var request = new Request("POST", endpoint);
+
+        StringBuilder body = new StringBuilder("{").append("\"query\":\"").append(query).append("\",").append("\"input\":[");
+
+        for (int i = 0; i < inputs.size(); i++) {
+            body.append("\"").append(inputs.get(i)).append("\"");
+            if (i < inputs.size() - 1) {
+                body.append(",");
+            }
+        }
+
+        body.append("]}");
+        request.setJsonEntity(body.toString());
+
+        var response = ESRestTestCase.client().performRequest(request);
+        ESRestTestCase.assertOK(response);
+        return ESRestTestCase.entityAsMap(response);
+    }
+
+    protected void put(String inferenceId, String modelConfig, TaskType taskType) throws IOException {
+        String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, inferenceId);
+        var request = new Request("PUT", endpoint);
+        request.setJsonEntity(modelConfig);
+        var response = ESRestTestCase.client().performRequest(request);
+        logger.warn("PUT response: {}", response.toString());
+        System.out.println("PUT response: " + response.toString());
+        ESRestTestCase.assertOKAndConsume(response);
+    }
+
+    protected static void assertOkOrCreated(Response response) throws IOException {
+        int statusCode = response.getStatusLine().getStatusCode();
+        // Once EntityUtils.toString(entity) is called the entity cannot be reused.
+        // Avoid that call with check here.
+        if (statusCode == 200 || statusCode == 201) {
+            return;
+        }
+
+        String responseStr = EntityUtils.toString(response.getEntity());
+        ESTestCase.assertThat(
+            responseStr,
+            response.getStatusLine().getStatusCode(),
+            Matchers.anyOf(Matchers.equalTo(200), Matchers.equalTo(201))
+        );
+    }
+}

+ 271 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/CohereServiceMixedIT.java

@@ -0,0 +1,271 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.qa.mixed;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
+import org.hamcrest.Matchers;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.oneOf;
+
+public class CohereServiceMixedIT extends BaseMixedTestCase {
+
+    private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
+    private static final String COHERE_RERANK_ADDED = "8.14.0";
+    private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
+    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";
+
+    private static MockWebServer cohereEmbeddingsServer;
+    private static MockWebServer cohereRerankServer;
+
+    @BeforeClass
+    public static void startWebServer() throws IOException {
+        cohereEmbeddingsServer = new MockWebServer();
+        cohereEmbeddingsServer.start();
+
+        cohereRerankServer = new MockWebServer();
+        cohereRerankServer.start();
+    }
+
+    @AfterClass
+    public static void shutdown() {
+        cohereEmbeddingsServer.close();
+        cohereRerankServer.close();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testCohereEmbeddings() throws IOException {
+        var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED));
+        assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);
+        assumeTrue(
+            "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION,
+            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
+        );
+
+        final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8";
+        final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float";
+
+        // queue a response as PUT will call the service
+        cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
+        put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
+
+        // float model
+        cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
+        put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);
+
+        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints");
+        assertEquals("cohere", configs.get(0).get("service"));
+        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
+        assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
+        var embeddingType = serviceSettings.get("embedding_type");
+        // An upgraded node will report the embedding type as byte, an old node int8
+        assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));
+
+        configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdFloat).get("endpoints");
+        serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
+        assertThat(serviceSettings, hasEntry("embedding_type", "float"));
+
+        assertEmbeddingInference(inferenceIdInt8, CohereEmbeddingType.BYTE);
+        assertEmbeddingInference(inferenceIdFloat, CohereEmbeddingType.FLOAT);
+
+        delete(inferenceIdFloat);
+        delete(inferenceIdInt8);
+
+    }
+
+    void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
+        switch (type) {
+            case INT8:
+            case BYTE:
+                cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
+                break;
+            case FLOAT:
+                cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
+        }
+
+        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
+        assertThat(inferenceMap.entrySet(), not(empty()));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testRerank() throws IOException {
+        var rerankSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED));
+        assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);
+        assumeTrue(
+            "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION,
+            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
+        );
+
+        final String inferenceId = "mixed-cluster-rerank";
+
+        put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
+        assertRerank(inferenceId);
+
+        var configs = (List<Map<String, Object>>) get(TaskType.RERANK, inferenceId).get("endpoints");
+        assertThat(configs, hasSize(1));
+        assertEquals("cohere", configs.get(0).get("service"));
+        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
+        assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0"));
+        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
+        assertThat(taskSettings, hasEntry("top_n", 3));
+
+        assertRerank(inferenceId);
+
+    }
+
+    private void assertRerank(String inferenceId) throws IOException {
+        cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
+        var inferenceMap = rerank(
+            inferenceId,
+            List.of("luke", "like", "leia", "chewy", "r2d2", "star", "wars"),
+            "star wars main character"
+        );
+        assertThat(inferenceMap.entrySet(), not(empty()));
+    }
+
+    private String embeddingConfigByte(String url) {
+        return embeddingConfigTemplate(url, "byte");
+    }
+
+    private String embeddingConfigInt8(String url) {
+        return embeddingConfigTemplate(url, "int8");
+    }
+
+    private String embeddingConfigFloat(String url) {
+        return embeddingConfigTemplate(url, "float");
+    }
+
+    private String embeddingConfigTemplate(String url, String embeddingType) {
+        return Strings.format("""
+            {
+                "service": "cohere",
+                "service_settings": {
+                    "url": "%s",
+                    "api_key": "XXXX",
+                    "model_id": "embed-english-light-v3.0",
+                    "embedding_type": "%s"
+                }
+            }
+            """, url, embeddingType);
+    }
+
+    private String embeddingResponseByte() {
+        return """
+            {
+                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+                "texts": [
+                    "hello"
+                ],
+                "embeddings": [
+                    [
+                        12,
+                        56
+                    ]
+                ],
+                "meta": {
+                    "api_version": {
+                        "version": "1"
+                    },
+                    "billed_units": {
+                        "input_tokens": 1
+                    }
+                },
+                "response_type": "embeddings_bytes"
+            }
+            """;
+    }
+
+    private String embeddingResponseFloat() {
+        return """
+            {
+                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
+                "texts": [
+                    "hello"
+                ],
+                "embeddings": [
+                    [
+                        -0.0018434525,
+                        0.01777649
+                    ]
+                ],
+                "meta": {
+                    "api_version": {
+                        "version": "1"
+                    },
+                    "billed_units": {
+                        "input_tokens": 1
+                    }
+                },
+                "response_type": "embeddings_floats"
+            }
+            """;
+    }
+
+    private String rerankConfig(String url) {
+        return Strings.format("""
+            {
+                "service": "cohere",
+                "service_settings": {
+                    "api_key": "XXXX",
+                    "model_id": "rerank-english-v3.0",
+                    "url": "%s"
+                },
+                "task_settings": {
+                    "return_documents": false,
+                    "top_n": 3
+                }
+            }
+            """, url);
+    }
+
+    private String rerankResponse() {
+        return """
+            {
+                "index": "d0760819-5a73-4d58-b163-3956d3648b62",
+                "results": [
+                    {
+                        "index": 2,
+                        "relevance_score": 0.98005307
+                    },
+                    {
+                        "index": 3,
+                        "relevance_score": 0.27904198
+                    },
+                    {
+                        "index": 0,
+                        "relevance_score": 0.10194652
+                    }
+                ],
+                "meta": {
+                    "api_version": {
+                        "version": "1"
+                    },
+                    "billed_units": {
+                        "search_units": 1
+                    }
+                }
+            }
+            """;
+    }
+
+}

+ 147 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/HuggingFaceServiceMixedIT.java

@@ -0,0 +1,147 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.qa.mixed;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+
+public class HuggingFaceServiceMixedIT extends BaseMixedTestCase {
+
+    private static final String HF_EMBEDDINGS_ADDED = "8.12.0";
+    private static final String HF_ELSER_ADDED = "8.12.0";
+    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";
+
+    private static MockWebServer embeddingsServer;
+    private static MockWebServer elserServer;
+
+    @BeforeClass
+    public static void startWebServer() throws IOException {
+        embeddingsServer = new MockWebServer();
+        embeddingsServer.start();
+
+        elserServer = new MockWebServer();
+        elserServer.start();
+    }
+
+    @AfterClass
+    public static void shutdown() {
+        embeddingsServer.close();
+        elserServer.close();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testHFEmbeddings() throws IOException {
+        var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(HF_EMBEDDINGS_ADDED));
+        assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported);
+        assumeTrue(
+            "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION,
+            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
+        );
+
+        final String inferenceId = "mixed-cluster-embeddings";
+
+        embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
+        put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING);
+        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
+        assertThat(configs, hasSize(1));
+        assertEquals("hugging_face", configs.get(0).get("service"));
+        assertEmbeddingInference(inferenceId);
+    }
+
+    void assertEmbeddingInference(String inferenceId) throws IOException {
+        embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
+        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
+        assertThat(inferenceMap.entrySet(), not(empty()));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testElser() throws IOException {
+        var supported = bwcVersion.onOrAfter(Version.fromString(HF_ELSER_ADDED));
+        assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported);
+        assumeTrue(
+            "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION,
+            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
+        );
+
+        final String inferenceId = "mixed-cluster-elser";
+        final String upgradedClusterId = "upgraded-cluster-elser";
+
+        put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);
+
+        var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints");
+        assertThat(configs, hasSize(1));
+        assertEquals("hugging_face", configs.get(0).get("service"));
+        assertElser(inferenceId);
+    }
+
+    private void assertElser(String inferenceId) throws IOException {
+        elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
+        var inferenceMap = inference(inferenceId, TaskType.SPARSE_EMBEDDING, "some text");
+        assertThat(inferenceMap.entrySet(), not(empty()));
+    }
+
+    private String embeddingConfig(String url) {
+        return Strings.format("""
+            {
+                "service": "hugging_face",
+                "service_settings": {
+                    "url": "%s",
+                    "api_key": "XXXX"
+                }
+            }
+            """, url);
+    }
+
+    private String embeddingResponse() {
+        return """
+            [
+                  [
+                      0.014539449,
+                      -0.015288644
+                  ]
+            ]
+            """;
+    }
+
+    private String elserConfig(String url) {
+        return Strings.format("""
+            {
+                "service": "hugging_face",
+                "service_settings": {
+                    "api_key": "XXXX",
+                    "url": "%s"
+                }
+            }
+            """, url);
+    }
+
+    private String elserResponse() {
+        return """
+            [
+                {
+                    ".": 0.133155956864357,
+                    "the": 0.6747211217880249
+                }
+            ]
+            """;
+    }
+
+}

+ 53 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClusterSpecTestCase.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.qa.mixed;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.elasticsearch.test.rest.TestFeatureService;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.ClassRule;
+
+public abstract class MixedClusterSpecTestCase extends ESRestTestCase {
+    @ClassRule
+    public static ElasticsearchCluster cluster = MixedClustersSpec.mixedVersionCluster();
+
+    @Override
+    protected String getTestRestCluster() {
+        return cluster.getHttpAddresses();
+    }
+
+    static final Version bwcVersion = Version.fromString(System.getProperty("tests.old_cluster_version"));
+
+    private static TestFeatureService oldClusterTestFeatureService = null;
+
+    @Before
+    public void extractOldClusterFeatures() {
+        if (oldClusterTestFeatureService == null) {
+            oldClusterTestFeatureService = testFeatureService;
+        }
+    }
+
+    protected static boolean oldClusterHasFeature(String featureId) {
+        assert oldClusterTestFeatureService != null;
+        return oldClusterTestFeatureService.clusterHasFeature(featureId);
+    }
+
+    protected static boolean oldClusterHasFeature(NodeFeature feature) {
+        return oldClusterHasFeature(feature.id());
+    }
+
+    @AfterClass
+    public static void cleanUp() {
+        oldClusterTestFeatureService = null;
+    }
+
+}

+ 25 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/MixedClustersSpec.java

@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.qa.mixed;
+
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.cluster.local.distribution.DistributionType;
+import org.elasticsearch.test.cluster.util.Version;
+
+public class MixedClustersSpec {
+    public static ElasticsearchCluster mixedVersionCluster() {
+        Version oldVersion = Version.fromString(System.getProperty("tests.old_cluster_version"));
+        return ElasticsearchCluster.local()
+            .distribution(DistributionType.DEFAULT)
+            .withNode(node -> node.version(oldVersion))
+            .withNode(node -> node.version(Version.CURRENT))
+            .setting("xpack.security.enabled", "false")
+            .setting("xpack.license.self_generated.type", "trial")
+            .build();
+    }
+}

+ 223 - 0
x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java

@@ -0,0 +1,223 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.qa.mixed;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+
+public class OpenAIServiceMixedIT extends BaseMixedTestCase {
+
+    private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0";
+    private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0";
+    private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0";
+    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";
+
+    private static MockWebServer openAiEmbeddingsServer;
+    private static MockWebServer openAiChatCompletionsServer;
+
+    @BeforeClass
+    public static void startWebServer() throws IOException {
+        openAiEmbeddingsServer = new MockWebServer();
+        openAiEmbeddingsServer.start();
+
+        openAiChatCompletionsServer = new MockWebServer();
+        openAiChatCompletionsServer.start();
+    }
+
+    @AfterClass
+    public static void shutdown() {
+        openAiEmbeddingsServer.close();
+        openAiChatCompletionsServer.close();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testOpenAiEmbeddings() throws IOException {
+        var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
+        assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported);
+        assumeTrue(
+            "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION,
+            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
+        );
+
+        final String inferenceId = "mixed-cluster-embeddings";
+
+        String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig();
+        // queue a response as PUT will call the service
+        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
+        put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);
+
+        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
+        assertThat(configs, hasSize(1));
+        assertEquals("openai", configs.get(0).get("service"));
+        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
+        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
+        var modelIdFound = serviceSettings.containsKey("model_id") || taskSettings.containsKey("model_id");
+        assertTrue("model_id not found in config: " + configs.toString(), modelIdFound);
+
+        assertEmbeddingInference(inferenceId);
+    }
+
+    void assertEmbeddingInference(String inferenceId) throws IOException {
+        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
+        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
+        assertThat(inferenceMap.entrySet(), not(empty()));
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testOpenAiCompletions() throws IOException {
+        var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
+        assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiEmbeddingsSupported);
+        assumeTrue(
+            "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION,
+            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
+        );
+
+        final String inferenceId = "mixed-cluster-completions";
+        final String upgradedClusterId = "upgraded-cluster-completions";
+
+        put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);
+
+        var configsMap = get(TaskType.COMPLETION, inferenceId);
+        logger.warn("Configs: {}", configsMap);
+        var configs = (List<Map<String, Object>>) configsMap.get("endpoints");
+        assertThat(configs, hasSize(1));
+        assertEquals("openai", configs.get(0).get("service"));
+        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
+        assertThat(serviceSettings, hasEntry("model_id", "gpt-4"));
+        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
+        assertThat(taskSettings.keySet(), empty());
+
+        assertCompletionInference(inferenceId);
+    }
+
+    void assertCompletionInference(String inferenceId) throws IOException {
+        openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse()));
+        var inferenceMap = inference(inferenceId, TaskType.COMPLETION, "some text");
+        assertThat(inferenceMap.entrySet(), not(empty()));
+    }
+
+    private String oldClusterVersionCompatibleEmbeddingConfig() {
+        if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED)) {
+            return embeddingConfigWithModelInTaskSettings(getUrl(openAiEmbeddingsServer));
+        } else {
+            return embeddingConfigWithModelInServiceSettings(getUrl(openAiEmbeddingsServer));
+        }
+    }
+
+    protected static org.elasticsearch.test.cluster.util.Version getOldClusterTestVersion() {
+        return org.elasticsearch.test.cluster.util.Version.fromString(bwcVersion.toString());
+    }
+
+    private String embeddingConfigWithModelInTaskSettings(String url) {
+        return Strings.format("""
+            {
+                "service": "openai",
+                "service_settings": {
+                    "api_key": "XXXX",
+                    "url": "%s"
+                },
+                "task_settings": {
+                   "model": "text-embedding-ada-002"
+                }
+            }
+            """, url);
+    }
+
+    static String embeddingConfigWithModelInServiceSettings(String url) {
+        return Strings.format("""
+            {
+                "service": "openai",
+                "service_settings": {
+                    "api_key": "XXXX",
+                    "url": "%s",
+                    "model_id": "text-embedding-ada-002"
+                }
+            }
+            """, url);
+    }
+
+    private String chatCompletionsConfig(String url) {
+        return Strings.format("""
+            {
+                "service": "openai",
+                "service_settings": {
+                    "api_key": "XXXX",
+                    "url": "%s",
+                    "model_id": "gpt-4"
+                }
+            }
+            """, url);
+    }
+
+    static String embeddingResponse() {
+        return """
+            {
+              "object": "list",
+              "data": [
+                  {
+                      "object": "embedding",
+                      "index": 0,
+                      "embedding": [
+                          0.0123,
+                          -0.0123
+                      ]
+                  }
+              ],
+              "model": "text-embedding-ada-002",
+              "usage": {
+                  "prompt_tokens": 8,
+                  "total_tokens": 8
+              }
+            }
+            """;
+    }
+
+    private String chatCompletionsResponse() {
+        return """
+            {
+              "id": "some-id",
+              "object": "chat.completion",
+              "created": 1705397787,
+              "model": "gpt-3.5-turbo-0613",
+              "choices": [
+                {
+                  "index": 0,
+                  "message": {
+                    "role": "assistant",
+                    "content": "some content"
+                  },
+                  "logprobs": null,
+                  "finish_reason": "stop"
+                }
+              ],
+              "usage": {
+                "prompt_tokens": 46,
+                "completion_tokens": 39,
+                "total_tokens": 85
+              },
+              "system_fingerprint": null
+            }
+            """;
+    }
+
+}