|
@@ -14,8 +14,6 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.
|
|
|
|
|
|
import java.io.IOException;
|
|
|
|
|
|
-import static org.hamcrest.Matchers.equalTo;
|
|
|
-
|
|
|
public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializingTestCase<TaskParams> {
|
|
|
|
|
|
@Override
|
|
@@ -34,33 +32,12 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
|
|
|
}
|
|
|
|
|
|
public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
|
|
|
- boolean canModelThreadsBeGreaterThanOne = randomBoolean();
|
|
|
- boolean canInferenceThreadsBeGreaterThanOne = canModelThreadsBeGreaterThanOne == false;
|
|
|
-
|
|
|
return new TaskParams(
|
|
|
randomAlphaOfLength(10),
|
|
|
randomNonNegativeLong(),
|
|
|
- canInferenceThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
|
|
|
- canModelThreadsBeGreaterThanOne ? randomIntBetween(1, 8) : 1,
|
|
|
+ randomIntBetween(1, 8),
|
|
|
+ randomIntBetween(1, 8),
|
|
|
randomIntBetween(1, 10000)
|
|
|
);
|
|
|
}
|
|
|
-
|
|
|
- public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreModelThreads() {
|
|
|
- TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 5, randomIntBetween(1, 10000));
|
|
|
- assertThat(taskParams.getModelThreads(), equalTo(9));
|
|
|
- assertThat(taskParams.getInferenceThreads(), equalTo(1));
|
|
|
- }
|
|
|
-
|
|
|
- public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndMoreInferenceThreads() {
|
|
|
- TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 3, 2, randomIntBetween(1, 10000));
|
|
|
- assertThat(taskParams.getModelThreads(), equalTo(1));
|
|
|
- assertThat(taskParams.getInferenceThreads(), equalTo(5));
|
|
|
- }
|
|
|
-
|
|
|
- public void testCtor_GivenBothModelAndInferenceThreadsGreaterThanOne_AndTie() {
|
|
|
- TaskParams taskParams = new TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 4, 4, randomIntBetween(1, 10000));
|
|
|
- assertThat(taskParams.getModelThreads(), equalTo(8));
|
|
|
- assertThat(taskParams.getInferenceThreads(), equalTo(1));
|
|
|
- }
|
|
|
}
|