1
0
Эх сурвалжийг харах

[ML] Require that threads_per_allocation is a power of 2 (#87697)

As the number of cores in a CPU is typically a power of 2,
this commit adds validation requiring that trained model deployments
are started with `threads_per_allocation` set to a power of 2.
When we decide how to distribute the allocations across the
cluster, this prevents situations where a lot of CPU cores
are wasted.

In addition, we add a maximum allowed value of `32`.
Dimitris Athanasiou 3 жил өмнө
parent
commit
679351e224

+ 1 - 1
docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc

@@ -70,7 +70,7 @@ the inference speed. The inference process is a compute-bound process; any numbe
 greater than the number of available hardware threads on the machine does not increase the
 inference speed. If this setting is greater than the number of hardware threads
 it will automatically be changed to a value less than the number of hardware threads.
-Defaults to 1.
+Defaults to 1. Must be a power of 2. Max allowed value is 32.
 
 `timeout`::
 (Optional, time)

+ 12 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -66,6 +66,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             AllocationStatus.State.STARTED,
             AllocationStatus.State.STARTING,
             AllocationStatus.State.FULLY_ALLOCATED };
+
+        private static final int MAX_THREADS_PER_ALLOCATION = 32;
+
         public static final ParseField MODEL_ID = new ParseField("model_id");
         public static final ParseField TIMEOUT = new ParseField("timeout");
         public static final ParseField WAIT_FOR = new ParseField("wait_for");
@@ -209,12 +212,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             if (threadsPerAllocation < 1) {
                 validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
             }
+            if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2(threadsPerAllocation) == false) {
+                validationException.addValidationError(
+                    "[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION
+                );
+            }
             if (queueCapacity < 1) {
                 validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
             }
             return validationException.validationErrors().isEmpty() ? null : validationException;
         }
 
+        private static boolean isPowerOf2(int value) {
+            return Integer.bitCount(value) == 1;
+        }
+
         @Override
         public int hashCode() {
             return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);

+ 30 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java

@@ -16,6 +16,10 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.
 import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
 
 import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -80,6 +84,32 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
         assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a positive integer"));
     }
 
+    public void testValidate_GivenThreadsPerAllocationIsNotPowerOf2() {
+        Set<Integer> powersOf2 = IntStream.range(0, 10).map(n -> (int) Math.pow(2, n)).boxed().collect(Collectors.toSet());
+        List<Integer> input = IntStream.range(1, 33).filter(n -> powersOf2.contains(n) == false).boxed().toList();
+
+        for (int n : input) {
+            Request request = createRandom();
+            request.setThreadsPerAllocation(n);
+
+            ActionRequestValidationException e = request.validate();
+
+            assertThat(e, is(not(nullValue())));
+            assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a power of 2 less than or equal to 32"));
+        }
+    }
+
+    public void testValidate_GivenThreadsPerAllocationIsValid() {
+        for (int n : List.of(1, 2, 4, 8, 16, 32)) {
+            Request request = createRandom();
+            request.setThreadsPerAllocation(n);
+
+            ActionRequestValidationException e = request.validate();
+
+            assertThat(e, is(nullValue()));
+        }
+    }
+
     public void testValidate_GivenNumberOfAllocationsIsZero() {
         Request request = createRandom();
         request.setNumberOfAllocations(0);