Browse Source

[ML] Log model download cancelled message as a warning not error (#110776)

David Kyle 1 year ago
parent
commit
48e6e0f86c

+ 17 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/Level.java

@@ -9,9 +9,23 @@ package org.elasticsearch.xpack.core.common.notifications;
 import java.util.Locale;
 
 public enum Level {
-    INFO,
-    WARNING,
-    ERROR;
+    INFO {
+        public org.apache.logging.log4j.Level log4jLevel() {
+            return org.apache.logging.log4j.Level.INFO;
+        }
+    },
+    WARNING {
+        public org.apache.logging.log4j.Level log4jLevel() {
+            return org.apache.logging.log4j.Level.WARN;
+        }
+    },
+    ERROR {
+        public org.apache.logging.log4j.Level log4jLevel() {
+            return org.apache.logging.log4j.Level.ERROR;
+        }
+    };
+
+    public abstract org.apache.logging.log4j.Level log4jLevel();
 
     /**
      * Case-insensitive from string method.

+ 6 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/common/notifications/LevelTests.java

@@ -32,4 +32,10 @@ public class LevelTests extends ESTestCase {
         assertThat(Level.WARNING.ordinal(), equalTo(1));
         assertThat(Level.ERROR.ordinal(), equalTo(2));
     }
+
+    public void testLog4JLevel() {
+        assertThat(Level.INFO.log4jLevel(), equalTo(org.apache.logging.log4j.Level.INFO));
+        assertThat(Level.WARNING.log4jLevel(), equalTo(org.apache.logging.log4j.Level.WARN));
+        assertThat(Level.ERROR.log4jLevel(), equalTo(org.apache.logging.log4j.Level.ERROR));
+    }
 }

+ 26 - 19
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java

@@ -27,6 +27,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskAwareRequest;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -141,26 +142,29 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
         try {
             final long relativeStartNanos = System.nanoTime();
 
-            logAndWriteNotificationAtInfo(auditClient, modelId, "starting model import");
+            logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO);
 
             modelImporter.doImport();
 
             final long totalRuntimeNanos = System.nanoTime() - relativeStartNanos;
-            logAndWriteNotificationAtInfo(
+            logAndWriteNotificationAtLevel(
                 auditClient,
                 modelId,
-                format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos))
+                format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)),
+                Level.INFO
             );
+        } catch (TaskCancelledException e) {
+            recordError(auditClient, modelId, exceptionRef, e, Level.WARNING);
         } catch (ElasticsearchException e) {
-            recordError(auditClient, modelId, exceptionRef, e);
+            recordError(auditClient, modelId, exceptionRef, e, Level.ERROR);
         } catch (MalformedURLException e) {
-            recordError(auditClient, modelId, "an invalid URL", exceptionRef, e, RestStatus.INTERNAL_SERVER_ERROR);
+            recordError(auditClient, modelId, "an invalid URL", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR);
         } catch (URISyntaxException e) {
-            recordError(auditClient, modelId, "an invalid URL syntax", exceptionRef, e, RestStatus.INTERNAL_SERVER_ERROR);
+            recordError(auditClient, modelId, "an invalid URL syntax", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR);
         } catch (IOException e) {
-            recordError(auditClient, modelId, "an IOException", exceptionRef, e, RestStatus.SERVICE_UNAVAILABLE);
+            recordError(auditClient, modelId, "an IOException", exceptionRef, e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE);
         } catch (Exception e) {
-            recordError(auditClient, modelId, "an Exception", exceptionRef, e, RestStatus.INTERNAL_SERVER_ERROR);
+            recordError(auditClient, modelId, "an Exception", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR);
         } finally {
             taskManager.unregister(task);
 
@@ -199,8 +203,15 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
         }, false);
     }
 
-    private static void recordError(Client client, String modelId, AtomicReference<Exception> exceptionRef, ElasticsearchException e) {
-        logAndWriteNotificationAtError(client, modelId, e.getDetailedMessage());
+    private static void recordError(
+        Client client,
+        String modelId,
+        AtomicReference<Exception> exceptionRef,
+        ElasticsearchException e,
+        Level level
+    ) {
+        String message = format("Model importing failed due to [%s]", e.getDetailedMessage());
+        logAndWriteNotificationAtLevel(client, modelId, message, level);
         exceptionRef.set(e);
     }
 
@@ -210,21 +221,17 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
         String failureType,
         AtomicReference<Exception> exceptionRef,
         Exception e,
+        Level level,
         RestStatus status
     ) {
         String message = format("Model importing failed due to %s [%s]", failureType, e);
-        logAndWriteNotificationAtError(client, modelId, message);
+        logAndWriteNotificationAtLevel(client, modelId, message, level);
         exceptionRef.set(new ElasticsearchStatusException(message, status, e));
     }
 
-    private static void logAndWriteNotificationAtError(Client client, String modelId, String message) {
-        writeNotification(client, modelId, message, Level.ERROR);
-        logger.error(format("[%s] %s", modelId, message));
-    }
-
-    private static void logAndWriteNotificationAtInfo(Client client, String modelId, String message) {
-        writeNotification(client, modelId, message, Level.INFO);
-        logger.info(format("[%s] %s", modelId, message));
+    private static void logAndWriteNotificationAtLevel(Client client, String modelId, String message, Level level) {
+        writeNotification(client, modelId, message, level);
+        logger.log(level.log4jLevel(), format("[%s] %s", modelId, message));
     }
 
     private static void writeNotification(Client client, String modelId, String message, Level level) {

+ 32 - 13
x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java

@@ -7,14 +7,17 @@
 
 package org.elasticsearch.xpack.ml.packageloader.action;
 
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.common.notifications.Level;
 import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
 import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
@@ -62,36 +65,44 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
 
     public void testSendsErrorNotificationForInternalError() throws URISyntaxException, IOException {
         ElasticsearchStatusException exception = new ElasticsearchStatusException("exception", RestStatus.INTERNAL_SERVER_ERROR);
+        String message = format("Model importing failed due to [%s]", exception.toString());
 
-        assertUploadCallsOnFailure(exception, exception.toString());
+        assertUploadCallsOnFailure(exception, message, Level.ERROR);
     }
 
     public void testSendsErrorNotificationForMalformedURL() throws URISyntaxException, IOException {
         MalformedURLException exception = new MalformedURLException("exception");
         String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL", exception.toString());
 
-        assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR);
+        assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR);
     }
 
     public void testSendsErrorNotificationForURISyntax() throws URISyntaxException, IOException {
         URISyntaxException exception = mock(URISyntaxException.class);
         String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL syntax", exception.toString());
 
-        assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR);
+        assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR);
     }
 
     public void testSendsErrorNotificationForIOException() throws URISyntaxException, IOException {
         IOException exception = mock(IOException.class);
         String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an IOException", exception.toString());
 
-        assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE);
+        assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE, Level.ERROR);
     }
 
     public void testSendsErrorNotificationForException() throws URISyntaxException, IOException {
         RuntimeException exception = mock(RuntimeException.class);
         String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an Exception", exception.toString());
 
-        assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR);
+        assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR);
+    }
+
+    public void testSendsWarningNotificationForTaskCancelledException() throws URISyntaxException, IOException {
+        TaskCancelledException exception = new TaskCancelledException("cancelled");
+        String message = format("Model importing failed due to [%s]", exception.toString());
+
+        assertUploadCallsOnFailure(exception, message, Level.WARNING);
     }
 
     public void testCallsOnResponseWithAcknowledgedResponse() throws URISyntaxException, IOException {
@@ -123,18 +134,24 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
         );
     }
 
-    private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status) throws URISyntaxException, IOException {
+    private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws URISyntaxException,
+        IOException {
         var esStatusException = new ElasticsearchStatusException(message, status, exception);
 
-        assertNotificationAndOnFailure(exception, esStatusException, message);
+        assertNotificationAndOnFailure(exception, esStatusException, message, level);
     }
 
-    private void assertUploadCallsOnFailure(ElasticsearchStatusException exception, String message) throws URISyntaxException, IOException {
-        assertNotificationAndOnFailure(exception, exception, message);
+    private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws URISyntaxException,
+        IOException {
+        assertNotificationAndOnFailure(exception, exception, message, level);
     }
 
-    private void assertNotificationAndOnFailure(Exception thrownException, ElasticsearchStatusException onFailureException, String message)
-        throws URISyntaxException, IOException {
+    private void assertNotificationAndOnFailure(
+        Exception thrownException,
+        ElasticsearchException onFailureException,
+        String message,
+        Level level
+    ) throws URISyntaxException, IOException {
         var client = mock(Client.class);
         var taskManager = mock(TaskManager.class);
         var task = mock(Task.class);
@@ -150,9 +167,11 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase {
         var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class);
         // 2 notifications- the starting message and the failure
         verify(client, times(2)).execute(eq(AuditMlNotificationAction.INSTANCE), notificationArg.capture(), any());
-        assertThat(notificationArg.getValue().getMessage(), is(message)); // the last message is captured
+        var notification = notificationArg.getValue();
+        assertThat(notification.getMessage(), is(message)); // the last message is captured
+        assertThat(notification.getLevel(), is(level)); // the last message is captured
 
-        var receivedException = (ElasticsearchStatusException) failureRef.get();
+        var receivedException = (ElasticsearchException) failureRef.get();
         assertThat(receivedException.toString(), is(onFailureException.toString()));
         assertThat(receivedException.status(), is(onFailureException.status()));
         assertThat(receivedException.getCause(), is(onFailureException.getCause()));