|
@@ -10,6 +10,8 @@ package org.elasticsearch.xpack.ml.inference.allocation;
|
|
|
import org.elasticsearch.ResourceNotFoundException;
|
|
|
import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
+import org.elasticsearch.action.search.SearchPhaseExecutionException;
|
|
|
+import org.elasticsearch.action.search.ShardSearchFailure;
|
|
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
|
|
import org.elasticsearch.cluster.ClusterChangedEvent;
|
|
|
import org.elasticsearch.cluster.ClusterName;
|
|
@@ -30,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
|
|
|
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
|
|
|
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
|
|
@@ -129,6 +132,38 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
|
|
|
verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService);
|
|
|
}
|
|
|
|
|
|
+ public void testLoadQueuedModelsWhenFailureIsRetried() {
|
|
|
+ String modelToLoad = "loading-model";
|
|
|
+ String failedModelToLoad = "failed-search-loading-model";
|
|
|
+ withSearchingLoadFailure(failedModelToLoad);
|
|
|
+ TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService();
|
|
|
+
|
|
|
+ trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad));
|
|
|
+ trainedModelAllocationNodeService.prepareModelToLoad(newParams(failedModelToLoad));
|
|
|
+
|
|
|
+ trainedModelAllocationNodeService.loadQueuedModels();
|
|
|
+
|
|
|
+ trainedModelAllocationNodeService.loadQueuedModels();
|
|
|
+
|
|
|
+ ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
|
|
|
+ ArgumentCaptor<UpdateTrainedModelAllocationStateAction.Request> requestCapture = ArgumentCaptor.forClass(
|
|
|
+ UpdateTrainedModelAllocationStateAction.Request.class
|
|
|
+ );
|
|
|
+ verify(deploymentManager, times(3)).startDeployment(startTaskCapture.capture(), any());
|
|
|
+ // Only the successful one is notifying, the failed one keeps retrying but not notifying as it is never successful
|
|
|
+ verify(trainedModelAllocationService, times(1)).updateModelAllocationState(requestCapture.capture(), any());
|
|
|
+
|
|
|
+ assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad));
|
|
|
+ assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad));
|
|
|
+ assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID));
|
|
|
+ assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED));
|
|
|
+
|
|
|
+ assertThat(startTaskCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad));
|
|
|
+ assertThat(startTaskCapture.getAllValues().get(2).getModelId(), equalTo(failedModelToLoad));
|
|
|
+
|
|
|
+ verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService);
|
|
|
+ }
|
|
|
+
|
|
|
public void testLoadQueuedModelsWhenStopped() {
|
|
|
TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService();
|
|
|
|
|
@@ -304,6 +339,7 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
|
|
|
String modelOne = "model-1";
|
|
|
String modelTwo = "model-2";
|
|
|
String notUsedModel = "model-3";
|
|
|
+ String previouslyUsedModel = "model-4";
|
|
|
ClusterChangedEvent event = new ClusterChangedEvent(
|
|
|
"testClusterChanged",
|
|
|
ClusterState.builder(new ClusterName("testClusterChanged"))
|
|
@@ -319,7 +355,28 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
|
|
|
)
|
|
|
.addNewAllocation(
|
|
|
modelTwo,
|
|
|
- TrainedModelAllocation.Builder.empty(newParams(modelTwo)).addNewRoutingEntry(NODE_ID)
|
|
|
+ TrainedModelAllocation.Builder
|
|
|
+ .empty(newParams(modelTwo))
|
|
|
+ .addNewRoutingEntry(NODE_ID)
|
|
|
+ .updateExistingRoutingEntry(
|
|
|
+ NODE_ID,
|
|
|
+ new RoutingStateAndReason(
|
|
|
+ randomFrom(RoutingState.STARTED, RoutingState.STARTING),
|
|
|
+ randomAlphaOfLength(10)
|
|
|
+ )
|
|
|
+ )
|
|
|
+ ).addNewAllocation(
|
|
|
+ previouslyUsedModel,
|
|
|
+ TrainedModelAllocation.Builder
|
|
|
+ .empty(newParams(modelTwo))
|
|
|
+ .addNewRoutingEntry(NODE_ID)
|
|
|
+ .updateExistingRoutingEntry(
|
|
|
+ NODE_ID,
|
|
|
+ new RoutingStateAndReason(
|
|
|
+ randomFrom(RoutingState.STOPPED, RoutingState.FAILED, RoutingState.STOPPING),
|
|
|
+ randomAlphaOfLength(10)
|
|
|
+ )
|
|
|
+ )
|
|
|
)
|
|
|
.addNewAllocation(
|
|
|
notUsedModel,
|
|
@@ -425,6 +482,20 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
|
|
|
}).when(deploymentManager).startDeployment(any(), any());
|
|
|
}
|
|
|
|
|
|
+ @SuppressWarnings({ "rawtypes", "unchecked" })
|
|
|
+ private void withSearchingLoadFailure(String modelId) {
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) invocationOnMock.getArguments()[0];
|
|
|
+ ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
|
|
|
+ if (task.getModelId().equals(modelId)) {
|
|
|
+ listener.onFailure(new SearchPhaseExecutionException("all shards failed", "foo", ShardSearchFailure.EMPTY_ARRAY));
|
|
|
+ } else {
|
|
|
+ listener.onResponse(task);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }).when(deploymentManager).startDeployment(any(), any());
|
|
|
+ }
|
|
|
+
|
|
|
private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
|
|
|
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1);
|
|
|
}
|