Browse Source

Revert "[ML] Only one of `inference_threads` and `model_threads` may be greater than one (#84794)" (#85089)

This reverts commit 4eaedb265dfddd1020e6e39d0797d378b8ce62a1.

On further investigation of how to improve allocation of trained models,
we concluded that being able to set `inference_threads` in combination with
`model_threads` is fundamental for scalability.
Dimitris Athanasiou 3 years ago
parent
commit
5d670e45ac

+ 0 - 5
docs/changelog/84794.yaml

@@ -1,5 +0,0 @@
-pr: 84794
-summary: Only one of `inference_threads` and `model_threads` may be greater than one
-area: Machine Learning
-type: bug
-issues: []

+ 2 - 3
docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc

@@ -57,9 +57,8 @@ Defaults to 1.
 
 [NOTE]
 =============================================
-Either one of `inference_threads` and `model_threads` may be greater than one, not both.
-Increase `inference_threads` to optimize for latency. Increase `model_threads` to optimize
-for throughput.
+If the sum of `inference_threads` and `model_threads` is greater than the number of
+hardware threads then the number of `inference_threads` will be reduced.
 =============================================
 
 `queue_capacity`::

+ 2 - 17
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -33,7 +33,6 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
-import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 
@@ -206,11 +205,6 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             if (inferenceThreads < 1) {
                 validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
             }
-            if (modelThreads > 1 && inferenceThreads > 1) {
-                validationException.addValidationError(
-                    "only one of " + List.of(INFERENCE_THREADS, MODEL_THREADS) + " may be greater than 1"
-                );
-            }
             if (queueCapacity < 1) {
                 validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
             }
@@ -289,18 +283,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
             this.modelId = Objects.requireNonNull(modelId);
             this.modelBytes = modelBytes;
+            this.inferenceThreads = inferenceThreads;
+            this.modelThreads = modelThreads;
             this.queueCapacity = queueCapacity;
-
-            // From 8.2 onwards, only one of inferenceThreads or modelThreads may be greater than 1.
-            // But for allocations started prior to 8.2, it may be that they're both greater than 1.
-            if (inferenceThreads > 1 && modelThreads > 1) {
-                // We ensure only one of the two is greater than 1.
-                this.modelThreads = modelThreads >= inferenceThreads ? modelThreads + inferenceThreads : 1;
-                this.inferenceThreads = inferenceThreads > modelThreads ? modelThreads + inferenceThreads : 1;
-            } else {
-                this.modelThreads = modelThreads;
-                this.inferenceThreads = inferenceThreads;
-            }
         }
 
         public TaskParams(StreamInput in) throws IOException {

+ 2 - 25
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java

@@ -14,8 +14,6 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.
 
 import java.io.IOException;
 
-import static org.hamcrest.Matchers.equalTo;
-
 public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializingTestCase<TaskParams> {
 
     @Override
@@ -34,33 +32,12 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
     }
 
     public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
-        boolean canModelThreadsBeGreaterThanOne = randomBoolean();
-        boolean canInferenceThreadsBeGreaterThanOne = canModelThreadsBeGreaterThanOne == false;
-
         return new TaskParams(
             randomAlphaOfLength(10),
             randomNonNegativeLong(),
-            canInferenceThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
-            canModelThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
+            randomIntBetween(1, 8),
+            randomIntBetween(1, 8),
             randomIntBetween(1, 10000)
         );
     }
-
-    public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreModelThreads() {
-        TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 5, randomIntBetween(1, 10000));
-        assertThat(taskParams.getModelThreads(), equalTo(9));
-        assertThat(taskParams.getInferenceThreads(), equalTo(1));
-    }
-
-    public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreInferenceThreads() {
-        TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 3, 2, randomIntBetween(1, 10000));
-        assertThat(taskParams.getModelThreads(), equalTo(1));
-        assertThat(taskParams.getInferenceThreads(), equalTo(5));
-    }
-
-    public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndTie() {
-        TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 4, randomIntBetween(1, 10000));
-        assertThat(taskParams.getModelThreads(), equalTo(8));
-        assertThat(taskParams.getInferenceThreads(), equalTo(1));
-    }
 }