|
@@ -16,17 +16,13 @@ import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
|
|
|
-import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
|
|
|
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
|
|
import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
|
|
|
|
|
|
import java.net.InetAddress;
|
|
|
-import java.util.Arrays;
|
|
|
|
|
|
import static org.hamcrest.Matchers.contains;
|
|
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
|
@@ -308,41 +304,6 @@ public class MlTasksTests extends ESTestCase {
|
|
|
assertThat(state, equalTo(DataFrameAnalyticsState.FAILED));
|
|
|
}
|
|
|
|
|
|
- public void testGetTrainedModelDeploymentState_GivenNull() {
|
|
|
- assertThat(MlTasks.getTrainedModelDeploymentState(null), equalTo(TrainedModelDeploymentState.STOPPED));
|
|
|
- }
|
|
|
-
|
|
|
- public void testGetTrainedModelDeploymentState_GivenTaskStateIsNull() {
|
|
|
- PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(null, false);
|
|
|
- assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
|
|
|
- }
|
|
|
-
|
|
|
- public void testGetTrainedModelDeploymentState_GivenTaskStateIsNotNullAndNotStale() {
|
|
|
- TrainedModelDeploymentState state = randomFrom(TrainedModelDeploymentState.values());
|
|
|
- PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, false);
|
|
|
- assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(state));
|
|
|
- }
|
|
|
-
|
|
|
- public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndStopping() {
|
|
|
- PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.STOPPING, true);
|
|
|
- assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STOPPED));
|
|
|
- }
|
|
|
-
|
|
|
- public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndFailed() {
|
|
|
- PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.FAILED, true);
|
|
|
- assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.FAILED));
|
|
|
- }
|
|
|
-
|
|
|
- public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndNotFailedNorStopping() {
|
|
|
- TrainedModelDeploymentState state = randomFrom(
|
|
|
- Arrays.stream(TrainedModelDeploymentState.values())
|
|
|
- .filter(s -> s != TrainedModelDeploymentState.FAILED && s != TrainedModelDeploymentState.STOPPING)
|
|
|
- .toArray(TrainedModelDeploymentState[]::new)
|
|
|
- );
|
|
|
- PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, true);
|
|
|
- assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
|
|
|
- }
|
|
|
-
|
|
|
private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAnalyticsTask(String jobId, String nodeId,
|
|
|
DataFrameAnalyticsState state,
|
|
|
boolean isStale) {
|
|
@@ -358,18 +319,4 @@ public class MlTasksTests extends ESTestCase {
|
|
|
return tasks.getTask(MlTasks.dataFrameAnalyticsTaskId(jobId));
|
|
|
}
|
|
|
|
|
|
- private static PersistentTasksCustomMetadata.PersistentTask<?> createTrainedModelTask(TrainedModelDeploymentState state,
|
|
|
- boolean isStale) {
|
|
|
- String id = randomAlphaOfLength(10);
|
|
|
- PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
|
|
|
- builder.addTask(MlTasks.trainedModelDeploymentTaskId(id), MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME,
|
|
|
- new StartTrainedModelDeploymentAction.TaskParams(id, randomAlphaOfLength(10), randomNonNegativeLong()),
|
|
|
- new PersistentTasksCustomMetadata.Assignment(randomAlphaOfLength(10), "test assignment"));
|
|
|
- if (state != null) {
|
|
|
- builder.updateTaskState(MlTasks.trainedModelDeploymentTaskId(id),
|
|
|
- new TrainedModelDeploymentTaskState(state, builder.getLastAllocationId() - (isStale ? 1 : 0), null));
|
|
|
- }
|
|
|
- PersistentTasksCustomMetadata tasks = builder.build();
|
|
|
- return tasks.getTask(MlTasks.trainedModelDeploymentTaskId(id));
|
|
|
- }
|
|
|
}
|