浏览代码

[ML] handle recovery from node removal better in model deployments (#78764)

* [ML] handle recovery from node removal better in model deployments

Node deployments previously did not handle node failures very well.

Now, if a deployment is seen in state, but a local task is not found,
the model will attempt to the loaded again.

Additionally, if the model loading fails due to a searching failure (shards or otherwise),
it will be retried at the next loading attempt.
Benjamin Trent 4 年之前
父节点
当前提交
5428c3c415

+ 9 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java

@@ -24,12 +24,21 @@ public enum RoutingState implements MemoryTrackedTaskState {
     }
 
     /**
+     * @param candidates one or more candidate states
      * @return {@code true} if state matches none of the given {@code candidates}
      */
     public boolean isNoneOf(RoutingState... candidates) {
         return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
     }
 
+    /**
+     * @param candidates one or more candidate states
+     * @return {@code true} if state matches one of the given {@code candidates}
+     */
+    public boolean isAnyOf(RoutingState... candidates) {
+        return Arrays.stream(candidates).anyMatch(candidate -> this == candidate);
+    }
+
     @Override
     public String toString() {
         return name().toLowerCase(Locale.ROOT);

+ 8 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.ml.action;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
@@ -15,6 +16,7 @@ import org.elasticsearch.action.support.tasks.TransportTasksAction;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -72,6 +74,12 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
             throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause());
         } else if (failedNodeExceptions.isEmpty() == false) {
             throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0));
+        } else if (tasks.isEmpty()) {
+            throw new ElasticsearchStatusException(
+                "[{}] unable to find deployment task for inference please stop and start the deployment or try again momentarily",
+                RestStatus.NOT_FOUND,
+                request.getDeploymentId()
+            );
         } else {
             return tasks.get(0);
         }

+ 17 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchPhaseExecutionException;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.cluster.ClusterChangedEvent;
@@ -40,6 +41,7 @@ import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
 import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
 
+import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Deque;
 import java.util.List;
@@ -160,6 +162,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
         TrainedModelDeploymentTask loadingTask;
         logger.trace("attempting to load all currently queued models");
         // NOTE: As soon as this method exits, the timer for the scheduler starts ticking
+        Deque<TrainedModelDeploymentTask> loadingToRetry = new ArrayDeque<>();
         while ((loadingTask = loadingModels.poll()) != null) {
             final String modelId = loadingTask.getModelId();
             if (loadingTask.isStopped()) {
@@ -174,17 +177,24 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
             }
             logger.trace(() -> new ParameterizedMessage("[{}] attempting to load model", modelId));
             final PlainActionFuture<TrainedModelDeploymentTask> listener = new PlainActionFuture<>();
-            deploymentManager.startDeployment(loadingTask, listener);
             try {
+                deploymentManager.startDeployment(loadingTask, listener);
                 // This needs to be synchronous here in the utility thread to keep queueing order
                 TrainedModelDeploymentTask deployedTask = listener.actionGet();
                 // kicks off asynchronous cluster state update
                 handleLoadSuccess(deployedTask);
             } catch (Exception ex) {
-                // kicks off asynchronous cluster state update
-                handleLoadFailure(loadingTask, ex);
+                if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) {
+                    handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(loadingTask.getModelId()));
+                } else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) {
+                    // A search phase execution failure should be retried, push task back to the queue
+                    loadingToRetry.add(loadingTask);
+                } else {
+                    handleLoadFailure(loadingTask, ex);
+                }
             }
         }
+        loadingModels.addAll(loadingToRetry);
     }
 
     public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason) {
@@ -270,9 +280,11 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
                 RoutingStateAndReason routingStateAndReason = trainedModelAllocation.getNodeRoutingTable().get(currentNode);
                 // Add new models to start loading
                 if (routingStateAndReason != null
-                    // periodic retries should be handled in a separate thread think
-                    && routingStateAndReason.getState().equals(RoutingState.STARTING)
+                    // periodic retries of `failed` should be handled in a separate process
+                    && routingStateAndReason.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED)
                     // This means we don't already have a task and should attempt creating one and starting the model loading
+                    // If we don't have a task but are STARTED, this means the cluster state had a started allocation,
+                    //   the node crashed and then started again
                     && modelIdToTask.containsKey(trainedModelAllocation.getTaskParams().getModelId()) == false
                     // If we are in reset mode, don't start loading a new model on this node.
                     && isResetMode == false) {

+ 72 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

@@ -10,6 +10,8 @@ package org.elasticsearch.xpack.ml.inference.allocation;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchPhaseExecutionException;
+import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.cluster.ClusterChangedEvent;
 import org.elasticsearch.cluster.ClusterName;
@@ -30,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
@@ -129,6 +132,38 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService);
     }
 
+    public void testLoadQueuedModelsWhenFailureIsRetried() {
+        String modelToLoad = "loading-model";
+        String failedModelToLoad = "failed-search-loading-model";
+        withSearchingLoadFailure(failedModelToLoad);
+        TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService();
+
+        trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad));
+        trainedModelAllocationNodeService.prepareModelToLoad(newParams(failedModelToLoad));
+
+        trainedModelAllocationNodeService.loadQueuedModels();
+
+        trainedModelAllocationNodeService.loadQueuedModels();
+
+        ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
+        ArgumentCaptor<UpdateTrainedModelAllocationStateAction.Request> requestCapture = ArgumentCaptor.forClass(
+            UpdateTrainedModelAllocationStateAction.Request.class
+        );
+        verify(deploymentManager, times(3)).startDeployment(startTaskCapture.capture(), any());
+        // Only the successful one is notifying, the failed one keeps retrying but not notifying as it is never successful
+        verify(trainedModelAllocationService, times(1)).updateModelAllocationState(requestCapture.capture(), any());
+
+        assertThat(startTaskCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad));
+        assertThat(requestCapture.getAllValues().get(0).getModelId(), equalTo(modelToLoad));
+        assertThat(requestCapture.getAllValues().get(0).getNodeId(), equalTo(NODE_ID));
+        assertThat(requestCapture.getAllValues().get(0).getRoutingState().getState(), equalTo(RoutingState.STARTED));
+
+        assertThat(startTaskCapture.getAllValues().get(1).getModelId(), equalTo(failedModelToLoad));
+        assertThat(startTaskCapture.getAllValues().get(2).getModelId(), equalTo(failedModelToLoad));
+
+        verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService);
+    }
+
     public void testLoadQueuedModelsWhenStopped() {
         TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService();
 
@@ -304,6 +339,7 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
         String modelOne = "model-1";
         String modelTwo = "model-2";
         String notUsedModel = "model-3";
+        String previouslyUsedModel = "model-4";
         ClusterChangedEvent event = new ClusterChangedEvent(
             "testClusterChanged",
             ClusterState.builder(new ClusterName("testClusterChanged"))
@@ -319,7 +355,28 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
                                 )
                                 .addNewAllocation(
                                     modelTwo,
-                                    TrainedModelAllocation.Builder.empty(newParams(modelTwo)).addNewRoutingEntry(NODE_ID)
+                                    TrainedModelAllocation.Builder
+                                        .empty(newParams(modelTwo))
+                                        .addNewRoutingEntry(NODE_ID)
+                                        .updateExistingRoutingEntry(
+                                            NODE_ID,
+                                            new RoutingStateAndReason(
+                                                randomFrom(RoutingState.STARTED, RoutingState.STARTING),
+                                                randomAlphaOfLength(10)
+                                            )
+                                        )
+                                ).addNewAllocation(
+                                    previouslyUsedModel,
+                                    TrainedModelAllocation.Builder
+                                        .empty(newParams(modelTwo))
+                                        .addNewRoutingEntry(NODE_ID)
+                                        .updateExistingRoutingEntry(
+                                            NODE_ID,
+                                            new RoutingStateAndReason(
+                                                randomFrom(RoutingState.STOPPED, RoutingState.FAILED, RoutingState.STOPPING),
+                                                randomAlphaOfLength(10)
+                                            )
+                                        )
                                 )
                                 .addNewAllocation(
                                     notUsedModel,
@@ -425,6 +482,20 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
         }).when(deploymentManager).startDeployment(any(), any());
     }
 
+    @SuppressWarnings({ "rawtypes", "unchecked" })
+    private void withSearchingLoadFailure(String modelId) {
+        doAnswer(invocationOnMock -> {
+            TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) invocationOnMock.getArguments()[0];
+            ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];
+            if (task.getModelId().equals(modelId)) {
+                listener.onFailure(new SearchPhaseExecutionException("all shards failed", "foo", ShardSearchFailure.EMPTY_ARRAY));
+            } else {
+                listener.onResponse(task);
+            }
+            return null;
+        }).when(deploymentManager).startDeployment(any(), any());
+    }
+
     private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
         return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1);
     }