
[ML] Adjust thread pool size for inference process for low priority d… (#91605)

In #91234 we introduced low priority deployments so that users can test
model functionality, provided there is enough memory to load the model,
even when CPU resources are fully used. However, we did not adjust the
thread pool size accordingly. This meant the thread pool could be
depleted, resulting in assignments hanging during their startup phase.
Dimitris Athanasiou 2 years ago
parent
commit
995f578d83
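
For context, the native inference comms thread pool previously scaled up to
allocatedProcessors * 3 threads, while each inference process permanently
holds 3 of them. A minimal sketch of the depletion, assuming a hypothetical
node with 8 allocated processors:

    // Pre-fix sizing: the pool maximum only accounted for normal priority
    // deployments, which are themselves bounded by the allocated processors.
    int allocatedProcessors = 8;              // hypothetical node size
    int oldPoolMax = allocatedProcessors * 3; // 24 comms threads
    // Low priority deployments are not bounded by processors, so adding two
    // of them to a fully used node pushes demand past the pool maximum:
    int processes = allocatedProcessors + 2;  // 8 normal + 2 low priority
    int threadsNeeded = processes * 3;        // 30 > 24
    // Processes that cannot obtain their 3 threads hang during startup.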

+ 20 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
+import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
 
 import java.io.IOException;
 import java.util.Base64;
@@ -755,8 +756,8 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putModelDefinition(modelId2);
         putVocabulary(List.of("these", "are", "my", "words"), modelId2);
 
-        startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1);
-        startDeployment(modelId2, AllocationStatus.State.STARTING.toString(), 1, 1);
+        startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
+        startDeployment(modelId2, AllocationStatus.State.STARTING.toString(), 1, 1, Priority.NORMAL);
 
         // Check second model did not get any allocations
         assertAllocationCount(modelId2, 0);
@@ -797,7 +798,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
 
         ResponseException ex = expectThrows(
             ResponseException.class,
-            () -> startDeployment(modelId, AllocationStatus.State.STARTED.toString(), 100, 1)
+            () -> startDeployment(modelId, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL)
         );
         assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(429));
         assertThat(
@@ -833,7 +834,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         putModelDefinition(modelId2);
         putVocabulary(List.of("these", "are", "my", "words"), modelId2);
 
-        startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1);
+        startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1, Priority.NORMAL);
 
         {
             Request request = new Request(
@@ -942,7 +943,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         createPassThroughModel(modelId);
         putModelDefinition(modelId);
         putVocabulary(List.of("these", "are", "my", "words"), modelId);
-        startDeployment(modelId, "started", 2, 1);
+        startDeployment(modelId, "started", 2, 1, Priority.NORMAL);
 
         assertBusy(() -> assertAllocationCount(modelId, 2));
 
@@ -951,6 +952,20 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         assertBusy(() -> assertAllocationCount(modelId, 1));
     }
 
+    public void testStartMultipleLowPriorityDeployments() throws Exception {
+        String modelId1 = "start_multiple_low_priority_deployments_1";
+        String modelId2 = "start_multiple_low_priority_deployments_2";
+        String modelId3 = "start_multiple_low_priority_deployments_3";
+        String modelId4 = "start_multiple_low_priority_deployments_4";
+        for (String modelId : List.of(modelId1, modelId2, modelId3, modelId4)) {
+            createPassThroughModel(modelId);
+            putModelDefinition(modelId);
+            putVocabulary(List.of("these", "are", "my", "words"), modelId);
+            startDeployment(modelId, "started", 1, 1, Priority.LOW);
+            assertAllocationCount(modelId, 1);
+        }
+    }
+
     private void putModelDefinition(String modelId) throws IOException {
         putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
     }

+ 11 - 3
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

@@ -17,6 +17,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.SecuritySettingsSourceField;
 import org.elasticsearch.test.rest.ESRestTestCase;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
+import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
 import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
@@ -191,11 +192,16 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
     }
 
     protected Response startDeployment(String modelId, String waitForState) throws IOException {
-        return startDeployment(modelId, waitForState, 1, 1);
+        return startDeployment(modelId, waitForState, 1, 1, Priority.NORMAL);
     }
 
-    protected Response startDeployment(String modelId, String waitForState, int numberOfAllocations, int threadsPerAllocation)
-        throws IOException {
+    protected Response startDeployment(
+        String modelId,
+        String waitForState,
+        int numberOfAllocations,
+        int threadsPerAllocation,
+        Priority priority
+    ) throws IOException {
         Request request = new Request(
             "POST",
             "/_ml/trained_models/"
@@ -206,6 +212,8 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
                 + threadsPerAllocation
                 + "&number_of_allocations="
                 + numberOfAllocations
+                + "&priority="
+                + priority
         );
         return client().performRequest(request);
     }
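
For illustration, a minimal sketch of calling the extended helper; the model
id is hypothetical, and the priority value is assumed to serialize as the
enum's lowercase name:

    // Hypothetical usage with the new priority parameter:
    startDeployment("my-model", "started", 1, 1, Priority.LOW);
    // issues: POST /_ml/trained_models/my-model/deployment/_start
    //             ?...&threads_per_allocation=1&number_of_allocations=1&priority=low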

+ 16 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -699,6 +699,12 @@ public class MachineLearning extends Plugin
      */
     public static final int MAX_TRAINED_MODEL_DEPLOYMENTS = 100;
 
+    /**
+     * The number of low priority models each node can host.
+     * Effectively, a value of 100 means the limit is purely based on memory.
+     */
+    public static final int MAX_LOW_PRIORITY_MODELS_PER_NODE = 100;
+
     private static final Logger logger = LogManager.getLogger(MachineLearning.class);
 
     private final Settings settings;
@@ -1041,7 +1047,9 @@ public class MachineLearning extends Plugin
             getLicenseState()
         );
         this.modelLoadingService.set(modelLoadingService);
-        this.deploymentManager.set(new DeploymentManager(client, xContentRegistry, threadPool, pyTorchProcessFactory));
+        this.deploymentManager.set(
+            new DeploymentManager(client, xContentRegistry, threadPool, pyTorchProcessFactory, getMaxModelDeploymentsPerNode())
+        );
 
         // Data frame analytics components
         AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(
@@ -1465,14 +1473,15 @@ public class MachineLearning extends Plugin
 
         // 3 threads per native inference process: for input, c++ logger output, and result processing.
         // As we cannot assign more models than the number of allocated processors, this thread pool's
-        // size is limited by the number of allocated processors on this node.
+        // size is limited by the number of allocated processors on this node. Additionally, we add
+        // the number of low priority model deployments per node.
         // Only use this thread pool for the main long-running process associated with a native inference model deployment.
         // (Using it for some other purpose could mean that an unrelated pytorch model assignment fails to start
         // or that whatever needed the thread for another purpose has to queue for a very long time.)
         ScalingExecutorBuilder pytorchComms = new ScalingExecutorBuilder(
             NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME,
             3,
-            getAllocatedProcessors().roundUp() * 3,
+            getMaxModelDeploymentsPerNode() * 3,
             TimeValue.timeValueMinutes(1),
             false,
             "xpack.ml.native_inference_comms_thread_pool"
@@ -1501,6 +1510,10 @@ public class MachineLearning extends Plugin
         return List.of(jobComms, pytorchComms, utility, datafeed);
     }
 
+    private int getMaxModelDeploymentsPerNode() {
+        return getAllocatedProcessors().roundUp() + MAX_LOW_PRIORITY_MODELS_PER_NODE;
+    }
+
     @Override
     public Map<String, AnalysisProvider<CharFilterFactory>> getCharFilters() {
         return MapBuilder.<String, AnalysisProvider<CharFilterFactory>>newMapBuilder()
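
Concretely, under the same hypothetical 8 allocated processors as above, the
new bound becomes:

    int maxDeploymentsPerNode = 8 + MAX_LOW_PRIORITY_MODELS_PER_NODE; // 8 + 100 = 108
    int newPoolMax = maxDeploymentsPerNode * 3;                       // 324, up from 24

This preserves the invariant that every process the node may host can hold
its 3 long-running comms threads.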

+ 1 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

@@ -40,18 +40,12 @@ import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.ml.MachineLearning.MAX_LOW_PRIORITY_MODELS_PER_NODE;
 
 class TrainedModelAssignmentRebalancer {
 
     private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentRebalancer.class);
 
-    /**
-     * We set the max number of low priority models per node to 100,
-     * a value that effectively removes the processor constraint and
-     * transforms the problem to memory bin packing.
-     */
-    private static final int MAX_LOW_PRIORITY_MODELS_PER_NODE = 100;
-
     private final TrainedModelAssignmentMetadata currentMetadata;
     private final Map<DiscoveryNode, NodeLoad> nodeLoads;
     private final Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone;

+ 11 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -77,12 +77,14 @@ public class DeploymentManager {
     private final ExecutorService executorServiceForProcess;
     private final ThreadPool threadPool;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
+    private final int maxProcesses;
 
     public DeploymentManager(
         Client client,
         NamedXContentRegistry xContentRegistry,
         ThreadPool threadPool,
-        PyTorchProcessFactory pyTorchProcessFactory
+        PyTorchProcessFactory pyTorchProcessFactory,
+        int maxProcesses
     ) {
         this.client = Objects.requireNonNull(client);
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
@@ -90,6 +92,7 @@ public class DeploymentManager {
         this.threadPool = Objects.requireNonNull(threadPool);
         this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
+        this.maxProcesses = maxProcesses;
     }
 
     public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> listener) {
@@ -123,6 +126,13 @@ public class DeploymentManager {
 
     // function exposed for testing
     ProcessContext addProcessContext(Long id, ProcessContext processContext) {
+        if (processContextByAllocation.size() >= maxProcesses) {
+            throw ExceptionsHelper.serverError(
+                "[{}] Could not start inference process as the node reached the max number [{}] of processes",
+                processContext.task.getModelId(),
+                maxProcesses
+            );
+        }
         return processContextByAllocation.putIfAbsent(id, processContext);
     }
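
The guard makes the limit explicit at process creation time instead of
letting startups queue on an exhausted pool. A standalone sketch of the
pattern, using plain JDK types rather than the actual ProcessContext:

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    class BoundedProcessRegistry {
        private final ConcurrentMap<Long, String> contexts = new ConcurrentHashMap<>();
        private final int maxProcesses;

        BoundedProcessRegistry(int maxProcesses) {
            this.maxProcesses = maxProcesses;
        }

        // Mirrors addProcessContext above: reject once the cap is reached,
        // otherwise register the context if the id is not already taken.
        String add(Long id, String context) {
            if (contexts.size() >= maxProcesses) {
                throw new IllegalStateException(
                    "node reached the max number [" + maxProcesses + "] of processes"
                );
            }
            return contexts.putIfAbsent(id, context);
        }
    }

In the actual change the error is raised via ExceptionsHelper.serverError and
carries the model id of the task that failed to start.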
 

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -75,7 +75,8 @@ public class DeploymentManagerTests extends ESTestCase {
             mock(Client.class),
             mock(NamedXContentRegistry.class),
             tp,
-            mock(PyTorchProcessFactory.class)
+            mock(PyTorchProcessFactory.class),
+            10
         );
 
         PriorityProcessWorkerExecutorService priorityExecutorService = new PriorityProcessWorkerExecutorService(