|
@@ -30,7 +30,6 @@ import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.tasks.TaskAwareRequest;
|
|
|
import org.elasticsearch.tasks.TaskCancelledException;
|
|
|
import org.elasticsearch.tasks.TaskId;
|
|
|
-import org.elasticsearch.tasks.TaskManager;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.common.notifications.Level;
|
|
@@ -42,6 +41,9 @@ import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPack
|
|
|
import java.io.IOException;
|
|
|
import java.net.MalformedURLException;
|
|
|
import java.net.URISyntaxException;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
|
@@ -49,7 +51,6 @@ import static org.elasticsearch.core.Strings.format;
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
|
import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION;
|
|
|
import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE;
|
|
|
-import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription;
|
|
|
|
|
|
public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<Request, AcknowledgedResponse> {
|
|
|
|
|
@@ -57,6 +58,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
|
|
|
private final Client client;
|
|
|
private final CircuitBreakerService circuitBreakerService;
|
|
|
+ final Map<String, List<DownloadTaskRemovedListener>> taskRemovedListenersByModelId;
|
|
|
|
|
|
@Inject
|
|
|
public TransportLoadTrainedModelPackage(
|
|
@@ -81,6 +83,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
);
|
|
|
this.client = new OriginSettingClient(client, ML_ORIGIN);
|
|
|
this.circuitBreakerService = circuitBreakerService;
|
|
|
+ taskRemovedListenersByModelId = new HashMap<>();
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -91,6 +94,12 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
@Override
|
|
|
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener)
|
|
|
throws Exception {
|
|
|
+ if (handleDownloadInProgress(request.getModelId(), request.isWaitForCompletion(), listener)) {
|
|
|
+ logger.debug("Existing download of model [{}] in progress", request.getModelId());
|
|
|
+ // download in progress, nothing to do
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
ModelDownloadTask downloadTask = createDownloadTask(request);
|
|
|
|
|
|
try {
|
|
@@ -107,7 +116,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
|
|
|
var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.<AcknowledgedResponse>noop();
|
|
|
|
|
|
- importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask);
|
|
|
+ importModel(client, () -> unregisterTask(downloadTask), request, modelImporter, downloadTask, downloadCompleteListener);
|
|
|
} catch (Exception e) {
|
|
|
taskManager.unregister(downloadTask);
|
|
|
listener.onFailure(e);
|
|
@@ -124,22 +133,91 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
return new ParentTaskAssigningClient(client, parentTaskId);
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Look for a current download task of the model and optionally wait
|
|
|
+ * for that task to complete if there is one.
|
|
|
+ * Synchronized with {@code unregisterTask} to prevent the task being
|
|
|
+ * removed before the remove listener is added.
|
|
|
+ * @param modelId Model being downloaded
|
|
|
+ * @param isWaitForCompletion Wait until the download completes before
|
|
|
+ * calling the listener
|
|
|
+ * @param listener Model download listener
|
|
|
+ * @return True if a download task is in progress
|
|
|
+ */
|
|
|
+ synchronized boolean handleDownloadInProgress(
|
|
|
+ String modelId,
|
|
|
+ boolean isWaitForCompletion,
|
|
|
+ ActionListener<AcknowledgedResponse> listener
|
|
|
+ ) {
|
|
|
+ var description = ModelDownloadTask.taskDescription(modelId);
|
|
|
+ var tasks = taskManager.getCancellableTasks().values();
|
|
|
+
|
|
|
+ ModelDownloadTask inProgress = null;
|
|
|
+ for (var task : tasks) {
|
|
|
+ if (description.equals(task.getDescription()) && task instanceof ModelDownloadTask downloadTask) {
|
|
|
+ inProgress = downloadTask;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (inProgress != null) {
|
|
|
+ if (isWaitForCompletion == false) {
|
|
|
+ // Not waiting for the download to complete, it is enough that the download is in progress
|
|
|
+ // Respond now, not when the download completes
|
|
|
+ listener.onResponse(AcknowledgedResponse.TRUE);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ // Otherwise register a task removed listener which is called
|
|
|
+ // once the task is complete and unregistered
|
|
|
+ var tracker = new DownloadTaskRemovedListener(inProgress, listener);
|
|
|
+ taskRemovedListenersByModelId.computeIfAbsent(modelId, s -> new ArrayList<>()).add(tracker);
|
|
|
+ taskManager.registerRemovedTaskListener(tracker);
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Unregister the completed task triggering any remove task listeners.
|
|
|
+ * This method is synchronized to prevent the task being removed while
|
|
|
+ * {@code handleDownloadInProgress} is in progress.
|
|
|
+ * @param task The completed task
|
|
|
+ */
|
|
|
+ synchronized void unregisterTask(ModelDownloadTask task) {
|
|
|
+ taskManager.unregister(task); // unregister will call the on remove function
|
|
|
+
|
|
|
+ var trackers = taskRemovedListenersByModelId.remove(task.getModelId());
|
|
|
+ if (trackers != null) {
|
|
|
+ for (var tracker : trackers) {
|
|
|
+ taskManager.unregisterRemovedTaskListener(tracker);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* This is package scope so that we can test the logic directly.
|
|
|
- * This should only be called from the masterOperation method and the tests
|
|
|
+ * This should only be called from the masterOperation method and the tests.
|
|
|
+ * This method is static for testing.
|
|
|
*
|
|
|
* @param auditClient a client which should only be used to send audit notifications. This client cannot be associated with the passed
|
|
|
* in task, that way when the task is cancelled the notification requests can
|
|
|
* still be performed. If it is associated with the task (i.e. via ParentTaskAssigningClient),
|
|
|
* then the requests will throw a TaskCancelledException.
|
|
|
+ * @param unregisterTaskFn Runnable to unregister the task. Because this is a static function
|
|
|
+ * a lambda is used rather than the instance method.
|
|
|
+ * @param request The download request
|
|
|
+ * @param modelImporter The importer
|
|
|
+ * @param task Download task
|
|
|
+ * @param listener Listener notified when the import succeeds or fails
|
|
|
*/
|
|
|
static void importModel(
|
|
|
Client auditClient,
|
|
|
- TaskManager taskManager,
|
|
|
+ Runnable unregisterTaskFn,
|
|
|
Request request,
|
|
|
ModelImporter modelImporter,
|
|
|
- ActionListener<AcknowledgedResponse> listener,
|
|
|
- Task task
|
|
|
+ ModelDownloadTask task,
|
|
|
+ ActionListener<AcknowledgedResponse> listener
|
|
|
) {
|
|
|
final String modelId = request.getModelId();
|
|
|
final long relativeStartNanos = System.nanoTime();
|
|
@@ -155,9 +233,12 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
Level.INFO
|
|
|
);
|
|
|
listener.onResponse(AcknowledgedResponse.TRUE);
|
|
|
- }, exception -> listener.onFailure(processException(auditClient, modelId, exception)));
|
|
|
+ }, exception -> {
|
|
|
+ task.setTaskException(exception);
|
|
|
+ listener.onFailure(processException(auditClient, modelId, exception));
|
|
|
+ });
|
|
|
|
|
|
- modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task)));
|
|
|
+ modelImporter.doImport(ActionListener.runAfter(finishListener, unregisterTaskFn));
|
|
|
}
|
|
|
|
|
|
static Exception processException(Client auditClient, String modelId, Exception e) {
|
|
@@ -197,14 +278,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
|
|
|
|
|
|
@Override
|
|
|
public ModelDownloadTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
|
|
|
- return new ModelDownloadTask(
|
|
|
- id,
|
|
|
- type,
|
|
|
- action,
|
|
|
- downloadModelTaskDescription(request.getModelId()),
|
|
|
- parentTaskId,
|
|
|
- headers
|
|
|
- );
|
|
|
+ return new ModelDownloadTask(id, type, action, request.getModelId(), parentTaskId, headers);
|
|
|
}
|
|
|
}, false);
|
|
|
}
|