Browse Source

[ML] Only one of `inference_threads` and `model_threads` may be great… (#84794)

Starting a trained model deployment the user may set values for `inference_threads`
of `model_threads`. The first improves latency whereas the latter improves throughput.
It is easier to reason on how a model allocation uses resources if we ensure only
one of those two may be greater than one. In addition, it allows us to distribute
the cores of the ML nodes in the cluster across the model allocations in the future.

This commit adds a validation that prevents both `inference_threads` and `model_threads`
to be greater than one.
Dimitris Athanasiou 3 years ago
parent
commit
4eaedb265d

+ 5 - 0
docs/changelog/84794.yaml

@@ -0,0 +1,5 @@
+pr: 84794
+summary: Only one of `inference_threads` and `model_threads` may be greater than one
+area: Machine Learning
+type: bug
+issues: []

+ 3 - 2
docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc

@@ -57,8 +57,9 @@ Defaults to 1.
 
 [NOTE]
 =============================================
-If the sum of `inference_threads` and `model_threads` is greater than the number of
-hardware threads then the number of `inference_threads` will be reduced.
+Either one of `inference_threads` and `model_threads` may be greater than one, not both.
+Increase `inference_threads` to optimize for latency. Increase `model_threads` to optimize
+for throughput.
 =============================================
 
 `queue_capacity`::

+ 17 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 
@@ -205,6 +206,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             if (inferenceThreads < 1) {
                 validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
             }
+            if (modelThreads > 1 && inferenceThreads > 1) {
+                validationException.addValidationError(
+                    "only one of " + List.of(INFERENCE_THREADS, MODEL_THREADS) + " may be greater than 1"
+                );
+            }
             if (queueCapacity < 1) {
                 validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
             }
@@ -283,9 +289,18 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
             this.modelId = Objects.requireNonNull(modelId);
             this.modelBytes = modelBytes;
-            this.inferenceThreads = inferenceThreads;
-            this.modelThreads = modelThreads;
             this.queueCapacity = queueCapacity;
+
+            // From 8.2 onwards, only one of inferenceThreads or modelThreads may be greater than 1.
+            // But for allocations started prior to 8.2, it may be that they're both greater than 1.
+            if (inferenceThreads > 1 && modelThreads > 1) {
+                // We ensure only one of the two is greater than 1.
+                this.modelThreads = modelThreads >= inferenceThreads ? modelThreads + inferenceThreads : 1;
+                this.inferenceThreads = inferenceThreads > modelThreads ? modelThreads + inferenceThreads : 1;
+            } else {
+                this.modelThreads = modelThreads;
+                this.inferenceThreads = inferenceThreads;
+            }
         }
 
         public TaskParams(StreamInput in) throws IOException {

+ 25 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java

@@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.
 
 import java.io.IOException;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializingTestCase<TaskParams> {
 
     @Override
@@ -32,12 +34,33 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
     }
 
     public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
+        boolean canModelThreadsBeGreaterThanOne = randomBoolean();
+        boolean canInferenceThreadsBeGreaterThanOne = canModelThreadsBeGreaterThanOne == false;
+
         return new TaskParams(
             randomAlphaOfLength(10),
             randomNonNegativeLong(),
-            randomIntBetween(1, 8),
-            randomIntBetween(1, 8),
+            canInferenceThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
+            canModelThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
             randomIntBetween(1, 10000)
         );
     }
+
+    public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreModelThreads() {
+        TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 5, randomIntBetween(1, 10000));
+        assertThat(taskParams.getModelThreads(), equalTo(9));
+        assertThat(taskParams.getInferenceThreads(), equalTo(1));
+    }
+
+    public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreInferenceThreads() {
+        TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 3, 2, randomIntBetween(1, 10000));
+        assertThat(taskParams.getModelThreads(), equalTo(1));
+        assertThat(taskParams.getInferenceThreads(), equalTo(5));
+    }
+
+    public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndTie() {
+        TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 4, randomIntBetween(1, 10000));
+        assertThat(taskParams.getModelThreads(), equalTo(8));
+        assertThat(taskParams.getInferenceThreads(), equalTo(1));
+    }
 }