Browse Source

[ML] Trained Model Deployment upgrade tests (#85402)

Full cluster restart and rolling upgrade tests for trained model deployments
David Kyle 3 years ago
parent
commit
760cb929a4

+ 0 - 8
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -35,14 +35,6 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
 
-/**
- * This test uses a mixture of HLRC and server side classes.
- *
- * The server classes have builders that set the one-time fields that
- * can only be set on creation e.g. create_time. The HLRC classes must
- * be used when creating PUT trained model requests as they do not set
- * these one-time fields.
- */
 public class TrainedModelIT extends ESRestTestCase {
 
     private static final String BASIC_AUTH_VALUE = UsernamePasswordToken.basicAuthHeaderValue(

+ 203 - 0
x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java

@@ -0,0 +1,203 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.restart;
+
+import org.apache.http.util.EntityUtils;
+import org.elasticsearch.Version;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
+import org.elasticsearch.upgrades.AbstractFullClusterRestartTestCase;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class MLModelDeploymentFullClusterRestartIT extends AbstractFullClusterRestartTestCase {
+
+    // See PyTorchModelIT for how this model was created
+    static final String BASE_64_ENCODED_MODEL =
+        "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp"
+            + "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA"
+            + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW"
+            + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh"
+            + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele"
+            + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k"
+            + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ"
+            + "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa"
+            + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq"
+            + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7"
+            + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3"
+            + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28"
+            + "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw"
+            + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW"
+            + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0"
+            + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts"
+            + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs"
+            + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn"
+            + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU"
+            + "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe"
+            + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE"
+            + "AAJIEAAAAAA==";
+    static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
+    static {
+        RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
+    }
+
+    @Before
+    public void setLogging() throws IOException {
+        Request loggingSettings = new Request("PUT", "_cluster/settings");
+        loggingSettings.setJsonEntity("""
+            {"persistent" : {
+                    "logger.org.elasticsearch.xpack.ml.inference.allocation" : "TRACE",
+                    "logger.org.elasticsearch.xpack.ml.inference.deployment" : "TRACE",
+                    "logger.org.elasticsearch.xpack.ml.process.logging" : "TRACE"
+                }}""");
+        client().performRequest(loggingSettings);
+    }
+
+    @Override
+    protected Settings restClientSettings() {
+        String token = "Basic " + Base64.getEncoder().encodeToString("test_user:x-pack-test-password".getBytes(StandardCharsets.UTF_8));
+        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
+    }
+
+    public void testDeploymentSurvivesRestart() throws Exception {
+        assumeTrue("NLP model deployments added in 8.0", getOldClusterVersion().onOrAfter(Version.V_8_0_0));
+
+        String modelId = "trained-model-full-cluster-restart";
+
+        if (isRunningAgainstOldCluster()) {
+            createTrainedModel(modelId);
+            putModelDefinition(modelId);
+            putVocabulary(List.of("these", "are", "my", "words"), modelId);
+            startDeployment(modelId);
+            assertInfer(modelId);
+        } else {
+            waitForDeploymentStarted(modelId);
+            assertInfer(modelId);
+            stopDeployment(modelId);
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private void waitForDeploymentStarted(String modelId) throws Exception {
+        assertBusy(() -> {
+            var response = getTrainedModelStats(modelId);
+            Map<String, Object> map = entityAsMap(response);
+            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
+            assertThat(stats, hasSize(1));
+            var stat = stats.get(0);
+            assertThat(
+                stat.toString(),
+                XContentMapValues.extractValue("deployment_stats.allocation_status.state", stat),
+                equalTo("fully_allocated")
+            );
+            assertThat(stat.toString(), XContentMapValues.extractValue("deployment_stats.state", stat), equalTo("started"));
+        }, 30, TimeUnit.SECONDS);
+    }
+
+    private void assertInfer(String modelId) throws IOException {
+        Response inference = infer("my words", modelId);
+        assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"predicted_value\":[[1.0,1.0]]}"));
+    }
+
+    private void putModelDefinition(String modelId) throws IOException {
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
+        request.setJsonEntity("""
+            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL));
+        client().performRequest(request);
+    }
+
+    private void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
+        List<String> vocabularyWithPad = new ArrayList<>();
+        vocabularyWithPad.add("[PAD]");
+        vocabularyWithPad.add("[UNK]");
+        vocabularyWithPad.addAll(vocabulary);
+        String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
+
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
+        request.setJsonEntity("""
+            { "vocabulary": [%s] }
+            """.formatted(quotedWords));
+        client().performRequest(request);
+    }
+
+    private void createTrainedModel(String modelId) throws IOException {
+        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
+        request.setJsonEntity("""
+            {
+               "description": "simple model for testing",
+               "model_type": "pytorch",
+               "inference_config": {
+                 "pass_through": {
+                   "tokenization": {
+                     "bert": {
+                       "with_special_tokens": false
+                     }
+                   }
+                 }
+               }
+             }""");
+        client().performRequest(request);
+    }
+
+    private Response startDeployment(String modelId) throws IOException {
+        return startDeployment(modelId, AllocationStatus.State.STARTED.toString());
+    }
+
+    private Response startDeployment(String modelId, String waitForState) throws IOException {
+        Request request = new Request(
+            "POST",
+            "/_ml/trained_models/"
+                + modelId
+                + "/deployment/_start?timeout=40s&wait_for="
+                + waitForState
+                + "&inference_threads=1&model_threads=1"
+        );
+        var response = client().performRequest(request);
+        assertOK(response);
+        return response;
+    }
+
+    private void stopDeployment(String modelId) throws IOException {
+        String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop";
+        Request request = new Request("POST", endpoint);
+        client().performRequest(request);
+    }
+
+    private Response getTrainedModelStats(String modelId) throws IOException {
+        Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
+        var response = client().performRequest(request);
+        assertOK(response);
+        return response;
+    }
+
+    private Response infer(String input, String modelId) throws IOException {
+        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer");
+        request.setJsonEntity("""
+            {  "docs": [{"input":"%s"}] }
+            """.formatted(input));
+
+        var response = client().performRequest(request);
+        assertOK(response);
+        return response;
+    }
+}

+ 202 - 0
x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java

@@ -0,0 +1,202 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.upgrades;
+
+import org.apache.http.util.EntityUtils;
+import org.elasticsearch.Version;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class MLModelDeploymentsUpgradeIT extends AbstractUpgradeTestCase {
+
+    // See PyTorchModelIT for how this model was created
+    static final String BASE_64_ENCODED_MODEL =
+        "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp"
+            + "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA"
+            + "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW"
+            + "lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh"
+            + "kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele"
+            + "s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k"
+            + "umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ"
+            + "wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa"
+            + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq"
+            + "+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7"
+            + "ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3"
+            + "FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28"
+            + "UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw"
+            + "vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW"
+            + "9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0"
+            + "Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts"
+            + "UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs"
+            + "BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn"
+            + "VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU"
+            + "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe"
+            + "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE"
+            + "AAJIEAAAAAA==";
+    static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
+    static {
+        RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
+    }
+
+    public void testTrainedModelDeployment() throws Exception {
+        assumeTrue("NLP model deployments added in 8.0", UPGRADE_FROM_VERSION.onOrAfter(Version.V_8_0_0));
+
+        final String modelId = "upgrade-deployment-test";
+
+        switch (CLUSTER_TYPE) {
+            case OLD -> {
+                setupDeployment(modelId);
+                assertInfer(modelId);
+            }
+            case MIXED -> {
+                ensureHealth(".ml-inference-*,.ml-config*", (request -> {
+                    request.addParameter("wait_for_status", "yellow");
+                    request.addParameter("timeout", "70s");
+                }));
+
+                waitForDeploymentStarted(modelId);
+                assertInfer(modelId);
+                assertInfer(modelId);
+            }
+            case UPGRADED -> {
+                ensureHealth(".ml-inference-*,.ml-config*", (request -> {
+                    request.addParameter("wait_for_status", "yellow");
+                    request.addParameter("timeout", "70s");
+                }));
+
+                waitForDeploymentStarted(modelId);
+                assertInfer(modelId);
+                stopDeployment(modelId);
+            }
+            default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]");
+        }
+    }
+
+    private void setupDeployment(String modelId) throws IOException {
+        createTrainedModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+    }
+
+    @SuppressWarnings("unchecked")
+    private void waitForDeploymentStarted(String modelId) throws Exception {
+        assertBusy(() -> {
+            var response = getTrainedModelStats(modelId);
+            Map<String, Object> map = entityAsMap(response);
+            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
+            assertThat(stats, hasSize(1));
+            var stat = stats.get(0);
+            assertThat(
+                stat.toString(),
+                XContentMapValues.extractValue("deployment_stats.allocation_status.state", stat),
+                equalTo("fully_allocated")
+            );
+            assertThat(stat.toString(), XContentMapValues.extractValue("deployment_stats.state", stat), equalTo("started"));
+        }, 30, TimeUnit.SECONDS);
+    }
+
+    private void assertInfer(String modelId) throws IOException {
+        Response inference = infer("my words", modelId);
+        assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"predicted_value\":[[1.0,1.0]]}"));
+    }
+
+    private void putModelDefinition(String modelId) throws IOException {
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
+        request.setJsonEntity("""
+            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL));
+        client().performRequest(request);
+    }
+
+    private void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
+        List<String> vocabularyWithPad = new ArrayList<>();
+        vocabularyWithPad.add("[PAD]");
+        vocabularyWithPad.add("[UNK]");
+        vocabularyWithPad.addAll(vocabulary);
+        String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
+
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
+        request.setJsonEntity("""
+            { "vocabulary": [%s] }
+            """.formatted(quotedWords));
+        client().performRequest(request);
+    }
+
+    private void createTrainedModel(String modelId) throws IOException {
+        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
+        request.setJsonEntity("""
+            {
+               "description": "simple model for testing",
+               "model_type": "pytorch",
+               "inference_config": {
+                 "pass_through": {
+                   "tokenization": {
+                     "bert": {
+                       "with_special_tokens": false
+                     }
+                   }
+                 }
+               }
+             }""");
+        client().performRequest(request);
+    }
+
+    private Response startDeployment(String modelId) throws IOException {
+        return startDeployment(modelId, "started");
+    }
+
+    private Response startDeployment(String modelId, String waitForState) throws IOException {
+        Request request = new Request(
+            "POST",
+            "/_ml/trained_models/"
+                + modelId
+                + "/deployment/_start?timeout=40s&wait_for="
+                + waitForState
+                + "&inference_threads=1&model_threads=1"
+        );
+        var response = client().performRequest(request);
+        assertOK(response);
+        return response;
+    }
+
+    private void stopDeployment(String modelId) throws IOException {
+        String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop";
+        Request request = new Request("POST", endpoint);
+        client().performRequest(request);
+    }
+
+    private Response getTrainedModelStats(String modelId) throws IOException {
+        Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
+        var response = client().performRequest(request);
+        assertOK(response);
+        return response;
+    }
+
+    private Response infer(String input, String modelId) throws IOException {
+        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer");
+        request.setJsonEntity("""
+            {  "docs": [{"input":"%s"}] }
+            """.formatted(input));
+
+        var response = client().performRequest(request);
+        assertOK(response);
+        return response;
+    }
+}