|
|
@@ -16,6 +16,10 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.
|
|
|
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Set;
|
|
|
+import java.util.stream.Collectors;
|
|
|
+import java.util.stream.IntStream;
|
|
|
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
@@ -80,6 +84,32 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
|
|
|
assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a positive integer"));
|
|
|
}
|
|
|
|
|
|
+ public void testValidate_GivenThreadsPerAllocationIsNotPowerOf2() {
|
|
|
+ Set<Integer> powersOf2 = IntStream.range(0, 10).map(n -> (int) Math.pow(2, n)).boxed().collect(Collectors.toSet());
|
|
|
+ List<Integer> input = IntStream.range(1, 33).filter(n -> powersOf2.contains(n) == false).boxed().toList();
|
|
|
+
|
|
|
+ for (int n : input) {
|
|
|
+ Request request = createRandom();
|
|
|
+ request.setThreadsPerAllocation(n);
|
|
|
+
|
|
|
+ ActionRequestValidationException e = request.validate();
|
|
|
+
|
|
|
+ assertThat(e, is(not(nullValue())));
|
|
|
+ assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a power of 2 less than or equal to 32"));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testValidate_GivenThreadsPerAllocationIsValid() {
|
|
|
+ for (int n : List.of(1, 2, 4, 8, 16, 32)) {
|
|
|
+ Request request = createRandom();
|
|
|
+ request.setThreadsPerAllocation(n);
|
|
|
+
|
|
|
+ ActionRequestValidationException e = request.validate();
|
|
|
+
|
|
|
+ assertThat(e, is(nullValue()));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testValidate_GivenNumberOfAllocationsIsZero() {
|
|
|
Request request = createRandom();
|
|
|
request.setNumberOfAllocations(0);
|