فهرست منبع

[ML] Use low priority deployments in trained models tests (#95490)

David Kyle 2 سال پیش
والد
کامیت
7e2ec7779a

+ 8 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MultipleDeploymentsIT.java

@@ -10,6 +10,8 @@ package org.elasticsearch.xpack.ml.integration;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Tuple;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
+import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.io.IOException;
@@ -28,13 +30,13 @@ public class MultipleDeploymentsIT extends PyTorchModelRestTestCase {
         putAllModelParts(baseModelId);
 
         String forSearch = "for-search";
-        startWithDeploymentId(baseModelId, forSearch);
+        startDeployment(baseModelId, forSearch, AllocationStatus.State.STARTED, 1, 1, Priority.LOW);
 
         Response inference = infer("my words", forSearch);
         assertOK(inference);
 
         String forIngest = "for-ingest";
-        startWithDeploymentId(baseModelId, forIngest);
+        startDeployment(baseModelId, forIngest, AllocationStatus.State.STARTED, 1, 1, Priority.LOW);
 
         inference = infer("my words", forIngest);
         assertOK(inference);
@@ -71,12 +73,13 @@ public class MultipleDeploymentsIT extends PyTorchModelRestTestCase {
         String modelWith2Deployments = "model-with-2-deployments";
         putAllModelParts(modelWith2Deployments);
         String forSearchDeployment = "for-search";
-        startWithDeploymentId(modelWith2Deployments, forSearchDeployment);
+        startDeployment(modelWith2Deployments, forSearchDeployment, AllocationStatus.State.STARTED, 1, 1, Priority.LOW);
+
         String forIngestDeployment = "for-ingest";
-        startWithDeploymentId(modelWith2Deployments, forIngestDeployment);
+        startDeployment(modelWith2Deployments, forIngestDeployment, AllocationStatus.State.STARTED, 1, 1, Priority.LOW);
 
         // deployment Id is the same as model
-        startDeployment(modelWith1Deployment);
+        startDeployment(modelWith1Deployment, modelWith1Deployment, AllocationStatus.State.STARTED, 1, 1, Priority.LOW);
 
         {
             Map<String, Object> stats = entityAsMap(getTrainedModelStats("_all"));

+ 15 - 15
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -190,7 +190,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putModelDefinition(modelStarted);
 
         CheckedBiConsumer<String, AllocationStatus.State, IOException> assertAtLeast = (modelId, state) -> {
-            startDeployment(modelId, state.toString());
+            startDeployment(modelId, state);
             Response response = getTrainedModelStats(modelId);
             var responseMap = entityAsMap(response);
             List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
@@ -246,7 +246,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         createPassThroughModel(modelId);
         putVocabulary(List.of("once", "twice"), modelId);
         putModelDefinition(modelId);
-        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED);
         {
             Response noInferenceCallsStatsResponse = getTrainedModelStats(modelId);
             List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(noInferenceCallsStatsResponse).get(
@@ -315,7 +315,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         request.setJsonEntity(Strings.format("""
             {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", length, poorlyFormattedModelBase64));
         client().performRequest(request);
-        startDeployment(badModel, AllocationStatus.State.STARTING.toString());
+        startDeployment(badModel, AllocationStatus.State.STARTING);
         assertBusy(() -> {
             Response noInferenceCallsStatsResponse = getTrainedModelStats(badModel);
             List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(noInferenceCallsStatsResponse).get(
@@ -340,8 +340,8 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putVocabulary(List.of("once", "twice"), modelBar);
         putModelDefinition(modelBar);
 
-        startDeployment(modelFoo, AllocationStatus.State.FULLY_ALLOCATED.toString());
-        startDeployment(modelBar, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelFoo, AllocationStatus.State.FULLY_ALLOCATED);
+        startDeployment(modelBar, AllocationStatus.State.FULLY_ALLOCATED);
         infer("once", modelFoo);
         infer("once", modelBar);
         {
@@ -372,8 +372,8 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putVocabulary(List.of("once", "twice"), modelBar);
         putModelDefinition(modelBar);
 
-        startDeployment(modelFoo, AllocationStatus.State.FULLY_ALLOCATED.toString());
-        startDeployment(modelBar, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelFoo, AllocationStatus.State.FULLY_ALLOCATED);
+        startDeployment(modelBar, AllocationStatus.State.FULLY_ALLOCATED);
         infer("once", modelFoo);
         infer("once", modelBar);
 
@@ -447,7 +447,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
             client().performRequest(clusterSettings);
         }
 
-        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED);
 
         List<String> inputs = List.of(
             "my words",
@@ -614,7 +614,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
 
         putVocabulary(List.of("once", "twice", "thrice"), modelId);
         putModelDefinition(modelId);
-        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED);
 
         String input = "once twice thrice";
         var e = expectThrows(ResponseException.class, () -> EntityUtils.toString(infer("once twice thrice", modelId).getEntity()));
@@ -846,8 +846,8 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putModelDefinition(modelId2);
         putVocabulary(List.of("these", "are", "my", "words"), modelId2);
 
-        startDeployment(modelId1, modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
-        startDeployment(modelId2, modelId2, AllocationStatus.State.STARTING.toString(), 1, 1, Priority.NORMAL);
+        startDeployment(modelId1, modelId1, AllocationStatus.State.STARTED, 100, 1, Priority.NORMAL);
+        startDeployment(modelId2, modelId2, AllocationStatus.State.STARTING, 1, 1, Priority.NORMAL);
 
         // Check second model did not get any allocations
         assertAllocationCount(modelId2, 0);
@@ -888,7 +888,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
 
         ResponseException ex = expectThrows(
             ResponseException.class,
-            () -> startDeployment(modelId, modelId, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL)
+            () -> startDeployment(modelId, modelId, AllocationStatus.State.STARTED, 100, 1, Priority.NORMAL)
         );
         assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(429));
         assertThat(
@@ -924,7 +924,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putModelDefinition(modelId2);
         putVocabulary(List.of("these", "are", "my", "words"), modelId2);
 
-        startDeployment(modelId1, modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
+        startDeployment(modelId1, modelId1, AllocationStatus.State.STARTED, 100, 1, Priority.NORMAL);
 
         {
             Request request = new Request(
@@ -1033,7 +1033,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         createPassThroughModel(modelId);
         putModelDefinition(modelId);
         putVocabulary(List.of("these", "are", "my", "words"), modelId);
-        startDeployment(modelId, modelId, "started", 2, 1, Priority.NORMAL);
+        startDeployment(modelId, modelId, AllocationStatus.State.STARTED, 2, 1, Priority.NORMAL);
 
         assertBusy(() -> assertAllocationCount(modelId, 2));
 
@@ -1051,7 +1051,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
             createPassThroughModel(modelId);
             putModelDefinition(modelId);
             putVocabulary(List.of("these", "are", "my", "words"), modelId);
-            startDeployment(modelId, modelId, "started", 1, 1, Priority.LOW);
+            startDeployment(modelId, modelId, AllocationStatus.State.STARTED, 1, 1, Priority.LOW);
             assertAllocationCount(modelId, 1);
         }
     }

+ 4 - 4
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

@@ -212,21 +212,21 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
     }
 
     protected Response startDeployment(String modelId) throws IOException {
-        return startDeployment(modelId, AllocationStatus.State.STARTED.toString());
+        return startDeployment(modelId, AllocationStatus.State.STARTED);
     }
 
     protected Response startWithDeploymentId(String modelId, String deploymentId) throws IOException {
-        return startDeployment(modelId, deploymentId, AllocationStatus.State.STARTED.toString(), 1, 1, Priority.NORMAL);
+        return startDeployment(modelId, deploymentId, AllocationStatus.State.STARTED, 1, 1, Priority.NORMAL);
     }
 
-    protected Response startDeployment(String modelId, String waitForState) throws IOException {
+    protected Response startDeployment(String modelId, AllocationStatus.State waitForState) throws IOException {
         return startDeployment(modelId, null, waitForState, 1, 1, Priority.NORMAL);
     }
 
     protected Response startDeployment(
         String modelId,
         String deploymentId,
-        String waitForState,
+        AllocationStatus.State waitForState,
         int numberOfAllocations,
         int threadsPerAllocation,
         Priority priority