|
@@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.
|
|
|
|
|
|
import java.io.IOException;
|
|
|
|
|
|
+import static org.hamcrest.Matchers.equalTo;
|
|
|
+
|
|
|
public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializingTestCase<TaskParams> {
|
|
|
|
|
|
@Override
|
|
@@ -32,12 +34,33 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
|
|
|
}
|
|
|
|
|
|
public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
|
|
|
+ boolean canModelThreadsBeGreaterThanOne = randomBoolean();
|
|
|
+ boolean canInferenceThreadsBeGreaterThanOne = canModelThreadsBeGreaterThanOne == false;
|
|
|
+
|
|
|
return new TaskParams(
|
|
|
randomAlphaOfLength(10),
|
|
|
randomNonNegativeLong(),
|
|
|
- randomIntBetween(1, 8),
|
|
|
- randomIntBetween(1, 8),
|
|
|
+ canInferenceThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
|
|
|
+ canModelThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
|
|
|
randomIntBetween(1, 10000)
|
|
|
);
|
|
|
}
|
|
|
+
|
|
|
+ public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreModelThreads() {
|
|
|
+ TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 5, randomIntBetween(1, 10000));
|
|
|
+ assertThat(taskParams.getModelThreads(), equalTo(9));
|
|
|
+ assertThat(taskParams.getInferenceThreads(), equalTo(1));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreInferenceThreads() {
|
|
|
+ TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 3, 2, randomIntBetween(1, 10000));
|
|
|
+ assertThat(taskParams.getModelThreads(), equalTo(1));
|
|
|
+ assertThat(taskParams.getInferenceThreads(), equalTo(5));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndTie() {
|
|
|
+ TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 4, randomIntBetween(1, 10000));
|
|
|
+ assertThat(taskParams.getModelThreads(), equalTo(8));
|
|
|
+ assertThat(taskParams.getInferenceThreads(), equalTo(1));
|
|
|
+ }
|
|
|
}
|