|
@@ -10,13 +10,19 @@ package org.elasticsearch.xpack.ml.packageloader.action;
|
|
import org.elasticsearch.ElasticsearchException;
|
|
import org.elasticsearch.ElasticsearchException;
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
import org.elasticsearch.action.ActionListener;
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
|
+import org.elasticsearch.action.support.ActionFilters;
|
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
|
import org.elasticsearch.client.internal.Client;
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
|
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
|
|
|
+import org.elasticsearch.cluster.service.ClusterService;
|
|
|
|
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
|
|
import org.elasticsearch.rest.RestStatus;
|
|
import org.elasticsearch.rest.RestStatus;
|
|
-import org.elasticsearch.tasks.Task;
|
|
|
|
import org.elasticsearch.tasks.TaskCancelledException;
|
|
import org.elasticsearch.tasks.TaskCancelledException;
|
|
|
|
+import org.elasticsearch.tasks.TaskId;
|
|
import org.elasticsearch.tasks.TaskManager;
|
|
import org.elasticsearch.tasks.TaskManager;
|
|
import org.elasticsearch.test.ESTestCase;
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
|
+import org.elasticsearch.transport.TransportService;
|
|
import org.elasticsearch.xpack.core.common.notifications.Level;
|
|
import org.elasticsearch.xpack.core.common.notifications.Level;
|
|
import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction;
|
|
import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction;
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
|
|
@@ -27,9 +33,13 @@ import org.mockito.ArgumentCaptor;
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
import java.net.MalformedURLException;
|
|
import java.net.MalformedURLException;
|
|
import java.net.URISyntaxException;
|
|
import java.net.URISyntaxException;
|
|
|
|
+import java.util.Map;
|
|
import java.util.concurrent.atomic.AtomicReference;
|
|
import java.util.concurrent.atomic.AtomicReference;
|
|
|
|
|
|
import static org.elasticsearch.core.Strings.format;
|
|
import static org.elasticsearch.core.Strings.format;
|
|
|
|
+import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION;
|
|
|
|
+import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE;
|
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
import static org.hamcrest.core.Is.is;
|
|
import static org.hamcrest.core.Is.is;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.eq;
|
|
import static org.mockito.ArgumentMatchers.eq;
|
|
@@ -37,6 +47,7 @@ import static org.mockito.Mockito.doAnswer;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
|
|
|
public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]";
|
|
private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]";
|
|
@@ -44,17 +55,10 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
public void testSendsFinishedUploadNotification() {
|
|
public void testSendsFinishedUploadNotification() {
|
|
var uploader = createUploader(null);
|
|
var uploader = createUploader(null);
|
|
var taskManager = mock(TaskManager.class);
|
|
var taskManager = mock(TaskManager.class);
|
|
- var task = mock(Task.class);
|
|
|
|
|
|
+ var task = mock(ModelDownloadTask.class);
|
|
var client = mock(Client.class);
|
|
var client = mock(Client.class);
|
|
|
|
|
|
- TransportLoadTrainedModelPackage.importModel(
|
|
|
|
- client,
|
|
|
|
- taskManager,
|
|
|
|
- createRequestWithWaiting(),
|
|
|
|
- uploader,
|
|
|
|
- ActionListener.noop(),
|
|
|
|
- task
|
|
|
|
- );
|
|
|
|
|
|
+ TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, ActionListener.noop());
|
|
|
|
|
|
var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
|
|
var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
|
|
// 2 notifications- the start and finish messages
|
|
// 2 notifications- the start and finish messages
|
|
@@ -108,32 +112,63 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
public void testCallsOnResponseWithAcknowledgedResponse() throws Exception {
|
|
public void testCallsOnResponseWithAcknowledgedResponse() throws Exception {
|
|
var client = mock(Client.class);
|
|
var client = mock(Client.class);
|
|
var taskManager = mock(TaskManager.class);
|
|
var taskManager = mock(TaskManager.class);
|
|
- var task = mock(Task.class);
|
|
|
|
|
|
+ var task = mock(ModelDownloadTask.class);
|
|
ModelImporter uploader = createUploader(null);
|
|
ModelImporter uploader = createUploader(null);
|
|
|
|
|
|
var responseRef = new AtomicReference<AcknowledgedResponse>();
|
|
var responseRef = new AtomicReference<AcknowledgedResponse>();
|
|
var listener = ActionListener.wrap(responseRef::set, e -> fail("received an exception: " + e.getMessage()));
|
|
var listener = ActionListener.wrap(responseRef::set, e -> fail("received an exception: " + e.getMessage()));
|
|
|
|
|
|
- TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task);
|
|
|
|
|
|
+ TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, listener);
|
|
assertThat(responseRef.get(), is(AcknowledgedResponse.TRUE));
|
|
assertThat(responseRef.get(), is(AcknowledgedResponse.TRUE));
|
|
}
|
|
}
|
|
|
|
|
|
public void testDoesNotCallListenerWhenNotWaitingForCompletion() {
|
|
public void testDoesNotCallListenerWhenNotWaitingForCompletion() {
|
|
var uploader = mock(ModelImporter.class);
|
|
var uploader = mock(ModelImporter.class);
|
|
var client = mock(Client.class);
|
|
var client = mock(Client.class);
|
|
- var taskManager = mock(TaskManager.class);
|
|
|
|
- var task = mock(Task.class);
|
|
|
|
-
|
|
|
|
|
|
+ var task = mock(ModelDownloadTask.class);
|
|
TransportLoadTrainedModelPackage.importModel(
|
|
TransportLoadTrainedModelPackage.importModel(
|
|
client,
|
|
client,
|
|
- taskManager,
|
|
|
|
|
|
+ () -> {},
|
|
createRequestWithoutWaiting(),
|
|
createRequestWithoutWaiting(),
|
|
uploader,
|
|
uploader,
|
|
- ActionListener.running(ESTestCase::fail),
|
|
|
|
- task
|
|
|
|
|
|
+ task,
|
|
|
|
+ ActionListener.running(ESTestCase::fail)
|
|
);
|
|
);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ public void testWaitForExistingDownload() {
|
|
|
|
+ var taskManager = mock(TaskManager.class);
|
|
|
|
+ var modelId = "foo";
|
|
|
|
+ var task = new ModelDownloadTask(1L, MODEL_IMPORT_TASK_TYPE, MODEL_IMPORT_TASK_ACTION, modelId, new TaskId("node", 1L), Map.of());
|
|
|
|
+ when(taskManager.getCancellableTasks()).thenReturn(Map.of(1L, task));
|
|
|
|
+
|
|
|
|
+ var transportService = mock(TransportService.class);
|
|
|
|
+ when(transportService.getTaskManager()).thenReturn(taskManager);
|
|
|
|
+
|
|
|
|
+ var action = new TransportLoadTrainedModelPackage(
|
|
|
|
+ transportService,
|
|
|
|
+ mock(ClusterService.class),
|
|
|
|
+ mock(ThreadPool.class),
|
|
|
|
+ mock(ActionFilters.class),
|
|
|
|
+ mock(IndexNameExpressionResolver.class),
|
|
|
|
+ mock(Client.class),
|
|
|
|
+ mock(CircuitBreakerService.class)
|
|
|
|
+ );
|
|
|
|
+
|
|
|
|
+ assertTrue(action.handleDownloadInProgress(modelId, true, ActionListener.noop()));
|
|
|
|
+ verify(taskManager).registerRemovedTaskListener(any());
|
|
|
|
+ assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1));
|
|
|
|
+ assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1));
|
|
|
|
+
|
|
|
|
+ // With wait for completion == false no new removed listener will be added
|
|
|
|
+ assertTrue(action.handleDownloadInProgress(modelId, false, ActionListener.noop()));
|
|
|
|
+ verify(taskManager, times(1)).registerRemovedTaskListener(any());
|
|
|
|
+ assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1));
|
|
|
|
+ assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1));
|
|
|
|
+
|
|
|
|
+ assertFalse(action.handleDownloadInProgress("no-task-for-this-one", randomBoolean(), ActionListener.noop()));
|
|
|
|
+ }
|
|
|
|
+
|
|
private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception {
|
|
private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception {
|
|
var esStatusException = new ElasticsearchStatusException(message, status, exception);
|
|
var esStatusException = new ElasticsearchStatusException(message, status, exception);
|
|
|
|
|
|
@@ -152,7 +187,7 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
) throws Exception {
|
|
) throws Exception {
|
|
var client = mock(Client.class);
|
|
var client = mock(Client.class);
|
|
var taskManager = mock(TaskManager.class);
|
|
var taskManager = mock(TaskManager.class);
|
|
- var task = mock(Task.class);
|
|
|
|
|
|
+ var task = mock(ModelDownloadTask.class);
|
|
ModelImporter uploader = createUploader(thrownException);
|
|
ModelImporter uploader = createUploader(thrownException);
|
|
|
|
|
|
var failureRef = new AtomicReference<Exception>();
|
|
var failureRef = new AtomicReference<Exception>();
|
|
@@ -160,7 +195,14 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
(AcknowledgedResponse response) -> { fail("received a acknowledged response: " + response.toString()); },
|
|
(AcknowledgedResponse response) -> { fail("received a acknowledged response: " + response.toString()); },
|
|
failureRef::set
|
|
failureRef::set
|
|
);
|
|
);
|
|
- TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task);
|
|
|
|
|
|
+ TransportLoadTrainedModelPackage.importModel(
|
|
|
|
+ client,
|
|
|
|
+ () -> taskManager.unregister(task),
|
|
|
|
+ createRequestWithWaiting(),
|
|
|
|
+ uploader,
|
|
|
|
+ task,
|
|
|
|
+ listener
|
|
|
|
+ );
|
|
|
|
|
|
var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
|
|
var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
|
|
// 2 notifications- the starting message and the failure
|
|
// 2 notifications- the starting message and the failure
|