|
@@ -7,14 +7,17 @@
|
|
|
|
|
|
package org.elasticsearch.xpack.ml.packageloader.action;
|
|
|
|
|
|
+import org.elasticsearch.ElasticsearchException;
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
+import org.elasticsearch.tasks.TaskCancelledException;
|
|
|
import org.elasticsearch.tasks.TaskManager;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.xpack.core.common.notifications.Level;
|
|
|
import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
|
|
@@ -62,36 +65,44 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
|
|
|
|
public void testSendsErrorNotificationForInternalError() throws URISyntaxException, IOException {
|
|
|
ElasticsearchStatusException exception = new ElasticsearchStatusException("exception", RestStatus.INTERNAL_SERVER_ERROR);
|
|
|
+ String message = format("Model importing failed due to [%s]", exception.toString());
|
|
|
|
|
|
- assertUploadCallsOnFailure(exception, exception.toString());
|
|
|
+ assertUploadCallsOnFailure(exception, message, Level.ERROR);
|
|
|
}
|
|
|
|
|
|
public void testSendsErrorNotificationForMalformedURL() throws URISyntaxException, IOException {
|
|
|
MalformedURLException exception = new MalformedURLException("exception");
|
|
|
String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL", exception.toString());
|
|
|
|
|
|
- assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR);
|
|
|
+ assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR);
|
|
|
}
|
|
|
|
|
|
public void testSendsErrorNotificationForURISyntax() throws URISyntaxException, IOException {
|
|
|
URISyntaxException exception = mock(URISyntaxException.class);
|
|
|
String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL syntax", exception.toString());
|
|
|
|
|
|
- assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR);
|
|
|
+ assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR);
|
|
|
}
|
|
|
|
|
|
public void testSendsErrorNotificationForIOException() throws URISyntaxException, IOException {
|
|
|
IOException exception = mock(IOException.class);
|
|
|
String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an IOException", exception.toString());
|
|
|
|
|
|
- assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE);
|
|
|
+ assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE, Level.ERROR);
|
|
|
}
|
|
|
|
|
|
public void testSendsErrorNotificationForException() throws URISyntaxException, IOException {
|
|
|
RuntimeException exception = mock(RuntimeException.class);
|
|
|
String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an Exception", exception.toString());
|
|
|
|
|
|
- assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR);
|
|
|
+ assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSendsWarningNotificationForTaskCancelledException() throws URISyntaxException, IOException {
|
|
|
+ TaskCancelledException exception = new TaskCancelledException("cancelled");
|
|
|
+ String message = format("Model importing failed due to [%s]", exception.toString());
|
|
|
+
|
|
|
+ assertUploadCallsOnFailure(exception, message, Level.WARNING);
|
|
|
}
|
|
|
|
|
|
public void testCallsOnResponseWithAcknowledgedResponse() throws URISyntaxException, IOException {
|
|
@@ -123,18 +134,24 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status) throws URISyntaxException, IOException {
|
|
|
+ private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws URISyntaxException,
|
|
|
+ IOException {
|
|
|
var esStatusException = new ElasticsearchStatusException(message, status, exception);
|
|
|
|
|
|
- assertNotificationAndOnFailure(exception, esStatusException, message);
|
|
|
+ assertNotificationAndOnFailure(exception, esStatusException, message, level);
|
|
|
}
|
|
|
|
|
|
- private void assertUploadCallsOnFailure(ElasticsearchStatusException exception, String message) throws URISyntaxException, IOException {
|
|
|
- assertNotificationAndOnFailure(exception, exception, message);
|
|
|
+ private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws URISyntaxException,
|
|
|
+ IOException {
|
|
|
+ assertNotificationAndOnFailure(exception, exception, message, level);
|
|
|
}
|
|
|
|
|
|
- private void assertNotificationAndOnFailure(Exception thrownException, ElasticsearchStatusException onFailureException, String message)
|
|
|
- throws URISyntaxException, IOException {
|
|
|
+ private void assertNotificationAndOnFailure(
|
|
|
+ Exception thrownException,
|
|
|
+ ElasticsearchException onFailureException,
|
|
|
+ String message,
|
|
|
+ Level level
|
|
|
+ ) throws URISyntaxException, IOException {
|
|
|
var client = mock(Client.class);
|
|
|
var taskManager = mock(TaskManager.class);
|
|
|
var task = mock(Task.class);
|
|
@@ -150,9 +167,11 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
|
|
|
var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
|
|
|
// 2 notifications- the starting message and the failure
|
|
|
verify(client, times(2)).execute(eq(AuditMlNotificationAction.INSTANCE), notificationArg.capture(), any());
|
|
|
- assertThat(notificationArg.getValue().getMessage(), is(message)); // the last message is captured
|
|
|
+ var notification = notificationArg.getValue();
|
|
|
+ assertThat(notification.getMessage(), is(message)); // the last message is captured
|
|
|
+ assertThat(notification.getLevel(), is(level)); // the last message is captured
|
|
|
|
|
|
- var receivedException = (ElasticsearchStatusException) failureRef.get();
|
|
|
+ var receivedException = (ElasticsearchException) failureRef.get();
|
|
|
assertThat(receivedException.toString(), is(onFailureException.toString()));
|
|
|
assertThat(receivedException.status(), is(onFailureException.status()));
|
|
|
assertThat(receivedException.getCause(), is(onFailureException.getCause()));
|