Browse Source

[ML] Protect against multiple concurrent downloads of the same model (#116869) (#117007)

Check for current downloading tasks in the download action.

# Conflicts:
#	x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java
David Kyle 11 months ago
parent
commit
4fac584981
12 changed files with 295 additions and 142 deletions
  1. 39 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java
  2. 6 1
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
  3. 5 31
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java
  4. 5 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java
  5. 35 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java
  6. 0 32
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java
  7. 0 35
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java
  8. 29 0
      x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java
  9. 21 2
      x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java
  10. 91 17
      x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java
  11. 62 20
      x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java
  12. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

+ 39 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

@@ -7,6 +7,9 @@
 
 package org.elasticsearch.xpack.inference;
 
+import org.elasticsearch.client.Response;
+import org.elasticsearch.client.ResponseListener;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
@@ -15,9 +18,12 @@ import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CountDownLatch;
 
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.oneOf;
@@ -100,4 +106,37 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest {
             Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
         );
     }
+
+    public void testMultipleInferencesTriggeringDownloadAndDeploy() throws InterruptedException {
+        int numParallelRequests = 4;
+        var latch = new CountDownLatch(numParallelRequests);
+        var errors = new ArrayList<Exception>();
+
+        var listener = new ResponseListener() {
+            @Override
+            public void onSuccess(Response response) {
+                latch.countDown();
+            }
+
+            @Override
+            public void onFailure(Exception exception) {
+                errors.add(exception);
+                latch.countDown();
+            }
+        };
+
+        var inputs = List.of("Hello World", "Goodnight moon");
+        var queryParams = Map.of("timeout", "120s");
+        for (int i = 0; i < numParallelRequests; i++) {
+            var request = createInferenceRequest(
+                Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
+                inputs,
+                queryParams
+            );
+            client().performRequestAsync(request, listener);
+        }
+
+        latch.await();
+        assertThat(errors.toString(), errors, empty());
+    }
 }

+ 6 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -373,12 +373,17 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return inferInternal(endpoint, input, queryParameters);
     }
 
-    private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
+    protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
         var request = new Request("POST", endpoint);
         request.setJsonEntity(jsonBody(input));
         if (queryParameters.isEmpty() == false) {
             request.addParameters(queryParameters);
         }
+        return request;
+    }
+
+    private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
+        var request = createInferenceRequest(endpoint, input, queryParameters);
         var response = client().performRequest(request);
         assertOkOrCreated(response);
         return entityAsMap(response);

+ 5 - 31
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java

@@ -7,14 +7,9 @@
 
 package org.elasticsearch.xpack.inference.services.elasticsearch;
 
-import org.elasticsearch.ResourceNotFoundException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.inference.ChunkingSettings;
-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 public class CustomElandModel extends ElasticsearchInternalModel {
 
@@ -39,31 +34,10 @@ public class CustomElandModel extends ElasticsearchInternalModel {
     }
 
     @Override
-    public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
-        Model model,
-        ActionListener<Boolean> listener
-    ) {
-
-        return new ActionListener<>() {
-            @Override
-            public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
-                listener.onResponse(Boolean.TRUE);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
-                    listener.onFailure(
-                        new ResourceNotFoundException(
-                            "Could not start the inference as the custom eland model [{0}] for this platform cannot be found."
-                                + " Custom models need to be loaded into the cluster with eland before they can be started.",
-                            internalServiceSettings.modelId()
-                        )
-                    );
-                    return;
-                }
-                listener.onFailure(e);
-            }
-        };
+    protected String modelNotFoundErrorMessage(String modelId) {
+        return "Could not deploy model ["
+            + modelId
+            + "] as the model cannot be found."
+            + " Custom models need to be loaded into the cluster with Eland before they can be started.";
     }
 }

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java

@@ -36,6 +36,11 @@ public class ElasticDeployedModel extends ElasticsearchInternalModel {
         throw new IllegalStateException("cannot start model that uses an existing deployment");
     }
 
+    @Override
+    protected String modelNotFoundErrorMessage(String modelId) {
+        throw new IllegalStateException("cannot start model [" + modelId + "] that uses an existing deployment");
+    }
+
     @Override
     public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
         Model model,

+ 35 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

@@ -7,6 +7,9 @@
 
 package org.elasticsearch.xpack.inference.services.elasticsearch;
 
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.core.TimeValue;
@@ -15,8 +18,10 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
 
@@ -79,10 +84,38 @@ public abstract class ElasticsearchInternalModel extends Model {
         return startRequest;
     }
 
-    public abstract ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
+    public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
         Model model,
         ActionListener<Boolean> listener
-    );
+    ) {
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
+                listener.onResponse(Boolean.TRUE);
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                var cause = ExceptionsHelper.unwrapCause(e);
+                if (cause instanceof ResourceNotFoundException) {
+                    listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
+                    return;
+                } else if (cause instanceof ElasticsearchStatusException statusException) {
+                    if (statusException.status() == RestStatus.CONFLICT
+                        && statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
+                        // Deployment is already started
+                        listener.onResponse(Boolean.TRUE);
+                    }
+                    return;
+                }
+                listener.onFailure(e);
+            }
+        };
+    }
+
+    protected String modelNotFoundErrorMessage(String modelId) {
+        return "Could not deploy model [" + modelId + "] as the model cannot be found.";
+    }
 
     public boolean usesExistingDeployment() {
         return internalServiceSettings.getDeploymentId() != null;

+ 0 - 32
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java

@@ -7,13 +7,8 @@
 
 package org.elasticsearch.xpack.inference.services.elasticsearch;
 
-import org.elasticsearch.ResourceNotFoundException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.inference.ChunkingSettings;
-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 public class ElserInternalModel extends ElasticsearchInternalModel {
 
@@ -37,31 +32,4 @@ public class ElserInternalModel extends ElasticsearchInternalModel {
     public ElserMlNodeTaskSettings getTaskSettings() {
         return (ElserMlNodeTaskSettings) super.getTaskSettings();
     }
-
-    @Override
-    public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
-        Model model,
-        ActionListener<Boolean> listener
-    ) {
-        return new ActionListener<>() {
-            @Override
-            public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
-                listener.onResponse(Boolean.TRUE);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
-                    listener.onFailure(
-                        new ResourceNotFoundException(
-                            "Could not start the ELSER service as the ELSER model for this platform cannot be found."
-                                + " ELSER needs to be downloaded before it can be started."
-                        )
-                    );
-                    return;
-                }
-                listener.onFailure(e);
-            }
-        };
-    }
 }

+ 0 - 35
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java

@@ -7,13 +7,8 @@
 
 package org.elasticsearch.xpack.inference.services.elasticsearch;
 
-import org.elasticsearch.ResourceNotFoundException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.inference.ChunkingSettings;
-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 public class MultilingualE5SmallModel extends ElasticsearchInternalModel {
 
@@ -31,34 +26,4 @@ public class MultilingualE5SmallModel extends ElasticsearchInternalModel {
     public MultilingualE5SmallInternalServiceSettings getServiceSettings() {
         return (MultilingualE5SmallInternalServiceSettings) super.getServiceSettings();
     }
-
-    @Override
-    public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
-        Model model,
-        ActionListener<Boolean> listener
-    ) {
-
-        return new ActionListener<>() {
-            @Override
-            public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
-                listener.onResponse(Boolean.TRUE);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
-                    listener.onFailure(
-                        new ResourceNotFoundException(
-                            "Could not start the TextEmbeddingService service as the "
-                                + "Multilingual-E5-Small model for this platform cannot be found."
-                                + " Multilingual-E5-Small needs to be downloaded before it can be started"
-                        )
-                    );
-                    return;
-                }
-                listener.onFailure(e);
-            }
-        };
-    }
-
 }

+ 29 - 0
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java

@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.packageloader.action;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.tasks.RemovedTaskListener;
+import org.elasticsearch.tasks.Task;
+
+public record DownloadTaskRemovedListener(ModelDownloadTask trackedTask, ActionListener<AcknowledgedResponse> listener)
+    implements
+        RemovedTaskListener {
+
+    @Override
+    public void onRemoved(Task task) {
+        if (task.getId() == trackedTask.getId()) {
+            if (trackedTask.getTaskException() == null) {
+                listener.onResponse(AcknowledgedResponse.TRUE);
+            } else {
+                listener.onFailure(trackedTask.getTaskException());
+            }
+        }
+    }
+}

+ 21 - 2
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java

@@ -13,6 +13,7 @@ import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.MlTasks;
 
 import java.io.IOException;
 import java.util.Map;
@@ -51,9 +52,12 @@ public class ModelDownloadTask extends CancellableTask {
     }
 
     private final AtomicReference<DownLoadProgress> downloadProgress = new AtomicReference<>(new DownLoadProgress(0, 0));
+    private final String modelId;
+    private volatile Exception taskException;
 
-    public ModelDownloadTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
-        super(id, type, action, description, parentTaskId, headers);
+    public ModelDownloadTask(long id, String type, String action, String modelId, TaskId parentTaskId, Map<String, String> headers) {
+        super(id, type, action, taskDescription(modelId), parentTaskId, headers);
+        this.modelId = modelId;
     }
 
     void setProgress(int totalParts, int downloadedParts) {
@@ -65,4 +69,19 @@ public class ModelDownloadTask extends CancellableTask {
         return new DownloadStatus(downloadProgress.get());
     }
 
+    public String getModelId() {
+        return modelId;
+    }
+
+    public void setTaskException(Exception exception) {
+        this.taskException = exception;
+    }
+
+    public Exception getTaskException() {
+        return taskException;
+    }
+
+    public static String taskDescription(String modelId) {
+        return MlTasks.downloadModelTaskDescription(modelId);
+    }
 }

+ 91 - 17
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java

@@ -30,7 +30,6 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskAwareRequest;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskId;
-import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.common.notifications.Level;
@@ -42,6 +41,9 @@ import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPack
 import java.io.IOException;
 import java.net.MalformedURLException;
 import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
@@ -49,7 +51,6 @@ import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION;
 import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE;
-import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription;
 
 public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<Request, AcknowledgedResponse> {
 
@@ -57,6 +58,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
 
     private final Client client;
     private final CircuitBreakerService circuitBreakerService;
+    final Map<String, List<DownloadTaskRemovedListener>> taskRemovedListenersByModelId;
 
     @Inject
     public TransportLoadTrainedModelPackage(
@@ -81,6 +83,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
         );
         this.client = new OriginSettingClient(client, ML_ORIGIN);
         this.circuitBreakerService = circuitBreakerService;
+        taskRemovedListenersByModelId = new HashMap<>();
     }
 
     @Override
@@ -91,6 +94,12 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
     @Override
     protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener)
         throws Exception {
+        if (handleDownloadInProgress(request.getModelId(), request.isWaitForCompletion(), listener)) {
+            logger.debug("Existing download of model [{}] in progress", request.getModelId());
+            // download in progress, nothing to do
+            return;
+        }
+
         ModelDownloadTask downloadTask = createDownloadTask(request);
 
         try {
@@ -107,7 +116,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
 
             var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.<AcknowledgedResponse>noop();
 
-            importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask);
+            importModel(client, () -> unregisterTask(downloadTask), request, modelImporter, downloadTask, downloadCompleteListener);
         } catch (Exception e) {
             taskManager.unregister(downloadTask);
             listener.onFailure(e);
@@ -124,22 +133,91 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
         return new ParentTaskAssigningClient(client, parentTaskId);
     }
 
+    /**
+     * Look for a current download task of the model and optionally wait
+     * for that task to complete if there is one.
+     * synchronized with {@code unregisterTask} to prevent the task being
+     * removed before the remove listener is added.
+     * @param modelId Model being downloaded
+     * @param isWaitForCompletion Wait until the download completes before
+     *                            calling the listener
+     * @param listener Model download listener
+     * @return True if a download task is in progress
+     */
+    synchronized boolean handleDownloadInProgress(
+        String modelId,
+        boolean isWaitForCompletion,
+        ActionListener<AcknowledgedResponse> listener
+    ) {
+        var description = ModelDownloadTask.taskDescription(modelId);
+        var tasks = taskManager.getCancellableTasks().values();
+
+        ModelDownloadTask inProgress = null;
+        for (var task : tasks) {
+            if (description.equals(task.getDescription()) && task instanceof ModelDownloadTask downloadTask) {
+                inProgress = downloadTask;
+                break;
+            }
+        }
+
+        if (inProgress != null) {
+            if (isWaitForCompletion == false) {
+                // Not waiting for the download to complete, it is enough that the download is in progress
+                // Respond now not when the download completes
+                listener.onResponse(AcknowledgedResponse.TRUE);
+                return true;
+            }
+            // Otherwise register a task removed listener which is called
+            // once the tasks is complete and unregistered
+            var tracker = new DownloadTaskRemovedListener(inProgress, listener);
+            taskRemovedListenersByModelId.computeIfAbsent(modelId, s -> new ArrayList<>()).add(tracker);
+            taskManager.registerRemovedTaskListener(tracker);
+            return true;
+        }
+
+        return false;
+    }
+
+    /**
+     * Unregister the completed task triggering any remove task listeners.
+     * This method is synchronized to prevent the task being removed while
+     * {@code waitForExistingDownload} is in progress.
+     * @param task The completed task
+     */
+    synchronized void unregisterTask(ModelDownloadTask task) {
+        taskManager.unregister(task); // unregister will call the on remove function
+
+        var trackers = taskRemovedListenersByModelId.remove(task.getModelId());
+        if (trackers != null) {
+            for (var tracker : trackers) {
+                taskManager.unregisterRemovedTaskListener(tracker);
+            }
+        }
+    }
+
     /**
      * This is package scope so that we can test the logic directly.
-     * This should only be called from the masterOperation method and the tests
+     * This should only be called from the masterOperation method and the tests.
+     * This method is static for testing.
      *
      * @param auditClient a client which should only be used to send audit notifications. This client cannot be associated with the passed
      *                    in task, that way when the task is cancelled the notification requests can
      *                    still be performed. If it is associated with the task (i.e. via ParentTaskAssigningClient),
      *                    then the requests will throw a TaskCancelledException.
+     * @param unregisterTaskFn Runnable to unregister the task. Because this is a static function
+     *                a lambda is used rather than the instance method.
+     * @param request The download request
+     * @param modelImporter The importer
+     * @param task Download task
+     * @param listener Listener
      */
     static void importModel(
         Client auditClient,
-        TaskManager taskManager,
+        Runnable unregisterTaskFn,
         Request request,
         ModelImporter modelImporter,
-        ActionListener<AcknowledgedResponse> listener,
-        Task task
+        ModelDownloadTask task,
+        ActionListener<AcknowledgedResponse> listener
     ) {
         final String modelId = request.getModelId();
         final long relativeStartNanos = System.nanoTime();
@@ -155,9 +233,12 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
                 Level.INFO
             );
             listener.onResponse(AcknowledgedResponse.TRUE);
-        }, exception -> listener.onFailure(processException(auditClient, modelId, exception)));
+        }, exception -> {
+            task.setTaskException(exception);
+            listener.onFailure(processException(auditClient, modelId, exception));
+        });
 
-        modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task)));
+        modelImporter.doImport(ActionListener.runAfter(finishListener, unregisterTaskFn));
     }
 
     static Exception processException(Client auditClient, String modelId, Exception e) {
@@ -197,14 +278,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
 
                 @Override
                 public ModelDownloadTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
-                    return new ModelDownloadTask(
-                        id,
-                        type,
-                        action,
-                        downloadModelTaskDescription(request.getModelId()),
-                        parentTaskId,
-                        headers
-                    );
+                    return new ModelDownloadTask(id, type, action, request.getModelId(), parentTaskId, headers);
                 }
             }, false);
         }

+ 62 - 20
x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java

@@ -10,13 +10,19 @@ package org.elasticsearch.xpack.ml.packageloader.action;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.common.notifications.Level;
 import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
@@ -27,9 +33,13 @@ import org.mockito.ArgumentCaptor;
 import java.io.IOException;
 import java.net.MalformedURLException;
 import java.net.URISyntaxException;
+import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION;
+import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.core.Is.is;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
@@ -37,6 +47,7 @@ import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class TransportLoadTrainedModelPackageTests extends ESTestCase {
     private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]";
@@ -44,17 +55,10 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
     public void testSendsFinishedUploadNotification() {
         var uploader = createUploader(null);
         var taskManager = mock(TaskManager.class);
-        var task = mock(Task.class);
+        var task = mock(ModelDownloadTask.class);
         var client = mock(Client.class);
 
-        TransportLoadTrainedModelPackage.importModel(
-            client,
-            taskManager,
-            createRequestWithWaiting(),
-            uploader,
-            ActionListener.noop(),
-            task
-        );
+        TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, ActionListener.noop());
 
         var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
         // 2 notifications- the start and finish messages
@@ -108,32 +112,63 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
     public void testCallsOnResponseWithAcknowledgedResponse() throws Exception {
         var client = mock(Client.class);
         var taskManager = mock(TaskManager.class);
-        var task = mock(Task.class);
+        var task = mock(ModelDownloadTask.class);
         ModelImporter uploader = createUploader(null);
 
         var responseRef = new AtomicReference<AcknowledgedResponse>();
         var listener = ActionListener.wrap(responseRef::set, e -> fail("received an exception: " + e.getMessage()));
 
-        TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task);
+        TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, listener);
         assertThat(responseRef.get(), is(AcknowledgedResponse.TRUE));
     }
 
     public void testDoesNotCallListenerWhenNotWaitingForCompletion() {
         var uploader = mock(ModelImporter.class);
         var client = mock(Client.class);
-        var taskManager = mock(TaskManager.class);
-        var task = mock(Task.class);
-
+        var task = mock(ModelDownloadTask.class);
         TransportLoadTrainedModelPackage.importModel(
             client,
-            taskManager,
+            () -> {},
             createRequestWithoutWaiting(),
             uploader,
-            ActionListener.running(ESTestCase::fail),
-            task
+            task,
+            ActionListener.running(ESTestCase::fail)
         );
     }
 
+    public void testWaitForExistingDownload() {
+        var taskManager = mock(TaskManager.class);
+        var modelId = "foo";
+        var task = new ModelDownloadTask(1L, MODEL_IMPORT_TASK_TYPE, MODEL_IMPORT_TASK_ACTION, modelId, new TaskId("node", 1L), Map.of());
+        when(taskManager.getCancellableTasks()).thenReturn(Map.of(1L, task));
+
+        var transportService = mock(TransportService.class);
+        when(transportService.getTaskManager()).thenReturn(taskManager);
+
+        var action = new TransportLoadTrainedModelPackage(
+            transportService,
+            mock(ClusterService.class),
+            mock(ThreadPool.class),
+            mock(ActionFilters.class),
+            mock(IndexNameExpressionResolver.class),
+            mock(Client.class),
+            mock(CircuitBreakerService.class)
+        );
+
+        assertTrue(action.handleDownloadInProgress(modelId, true, ActionListener.noop()));
+        verify(taskManager).registerRemovedTaskListener(any());
+        assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1));
+        assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1));
+
+        // With wait for completion == false no new removed listener will be added
+        assertTrue(action.handleDownloadInProgress(modelId, false, ActionListener.noop()));
+        verify(taskManager, times(1)).registerRemovedTaskListener(any());
+        assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1));
+        assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1));
+
+        assertFalse(action.handleDownloadInProgress("no-task-for-this-one", randomBoolean(), ActionListener.noop()));
+    }
+
     private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception {
         var esStatusException = new ElasticsearchStatusException(message, status, exception);
 
@@ -152,7 +187,7 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
     ) throws Exception {
         var client = mock(Client.class);
         var taskManager = mock(TaskManager.class);
-        var task = mock(Task.class);
+        var task = mock(ModelDownloadTask.class);
         ModelImporter uploader = createUploader(thrownException);
 
         var failureRef = new AtomicReference<Exception>();
@@ -160,7 +195,14 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
             (AcknowledgedResponse response) -> { fail("received a acknowledged response: " + response.toString()); },
             failureRef::set
         );
-        TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task);
+        TransportLoadTrainedModelPackage.importModel(
+            client,
+            () -> taskManager.unregister(task),
+            createRequestWithWaiting(),
+            uploader,
+            task,
+            listener
+        );
 
         var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
         // 2 notifications- the starting message and the failure

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -190,11 +190,11 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
                     () -> "[" + request.getDeploymentId() + "] creating new assignment for model [" + request.getModelId() + "] failed",
                     e
                 );
-                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
+                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException resourceAlreadyExistsException) {
                     e = new ElasticsearchStatusException(
                         "Cannot start deployment [{}] because it has already been started",
                         RestStatus.CONFLICT,
-                        e,
+                        resourceAlreadyExistsException,
                         request.getDeploymentId()
                     );
                 }