浏览代码

[ML] Validate trained model deployment queue_capacity limit (#89573)

When starting a trained model deployment, a queue is created.
If the queue_capacity is too large, it can lead to OOM and a node
crash.

This commit adds validation that the queue_capacity cannot be more
than 1M.

Closes #89555
Dimitris Athanasiou 3 年之前
父节点
当前提交
32d512286d

+ 6 - 0
docs/changelog/89573.yaml

@@ -0,0 +1,6 @@
+pr: 89573
+summary: Validate trained model deployment `queue_capacity` limit
+area: Machine Learning
+type: bug
+issues:
+ - 89555

+ 1 - 1
docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc

@@ -71,7 +71,7 @@ Defaults to 1.
 Controls how many inference requests are allowed in the queue at a time.
 Every machine learning node in the cluster where the model can be allocated
 has a queue of this size; when the number of requests exceeds the total value,
-new requests are rejected with a 429 error. Defaults to 1024.
+new requests are rejected with a 429 error. Defaults to 1024. Max allowed value is 1000000.
 
 `threads_per_allocation`::
 (Optional, integer)

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -71,6 +71,10 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             AllocationStatus.State.FULLY_ALLOCATED };
 
         private static final int MAX_THREADS_PER_ALLOCATION = 32;
+        /**
+         * If the queue capacity is too large then we can OOM when we create the queue.
+         */
+        private static final int MAX_QUEUE_CAPACITY = 1_000_000;
 
         public static final ParseField MODEL_ID = new ParseField("model_id");
         public static final ParseField TIMEOUT = new ParseField("timeout");
@@ -248,6 +252,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             if (queueCapacity < 1) {
                 validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
             }
+            if (queueCapacity > MAX_QUEUE_CAPACITY) {
+                validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be less than " + MAX_QUEUE_CAPACITY);
+            }
             return validationException.validationErrors().isEmpty() ? null : validationException;
         }
 

+ 20 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java

@@ -59,7 +59,7 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
             request.setNumberOfAllocations(randomIntBetween(1, 8));
         }
         if (randomBoolean()) {
-            request.setQueueCapacity(randomIntBetween(1, 10000));
+            request.setQueueCapacity(randomIntBetween(1, 1000000));
         }
         return request;
     }
@@ -150,6 +150,25 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
         assertThat(e.getMessage(), containsString("[queue_capacity] must be a positive integer"));
     }
 
+    public void testValidate_GivenQueueCapacityIsAtLimit() {
+        Request request = createRandom();
+        request.setQueueCapacity(1_000_000);
+
+        ActionRequestValidationException e = request.validate();
+
+        assertThat(e, is(nullValue()));
+    }
+
+    public void testValidate_GivenQueueCapacityIsOverLimit() {
+        Request request = createRandom();
+        request.setQueueCapacity(1_000_001);
+
+        ActionRequestValidationException e = request.validate();
+
+        assertThat(e, is(not(nullValue())));
+        assertThat(e.getMessage(), containsString("[queue_capacity] must be less than 1000000"));
+    }
+
     public void testDefaults() {
         Request request = new Request(randomAlphaOfLength(10));
         assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));