Просмотр исходного кода

[ML] Improve EIS auth call logs and fix revocation bug (#132546) (#132690)

* Fixing revoking and adding logs

* Fixing tests

* Update docs/changelog/132546.yaml

* [CI] Auto commit changes from spotless

* Addressing feedback

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Jonathan Buttner 2 месяца назад
Родитель
Commit
ab93718dc8

+ 5 - 0
docs/changelog/132546.yaml

@@ -0,0 +1,5 @@
+pr: 132546
+summary: Improve EIS auth call logs and fix revocation bug
+area: Machine Learning
+type: bug
+issues: []

+ 12 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java

@@ -243,14 +243,15 @@ public class ElasticInferenceServiceAuthorizationHandler implements Closeable {
     }
 
     private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) {
-        logger.debug("Received authorization response");
-        var authorizedTaskTypesAndModels = authorizedContent.get().taskTypesAndModels.merge(auth)
-            .newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
+        logger.debug(() -> Strings.format("Received authorization response, %s", auth));
+
+        var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes));
+        logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels));
 
         // recalculate which default config ids and models are authorized now
-        var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(auth);
+        var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels);
 
-        var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, auth);
+        var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels);
         var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds);
         authorizedContent.set(
             new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects)
@@ -337,7 +338,12 @@ public class ElasticInferenceServiceAuthorizationHandler implements Closeable {
             firstAuthorizationCompletedLatch.countDown();
         });
 
-        logger.debug("Synchronizing default inference endpoints");
+        logger.debug(
+            () -> Strings.format(
+                "Synchronizing default inference endpoints, attempting to remove ids: %s",
+                unauthorizedDefaultInferenceEndpointIds
+            )
+        );
         modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener);
     }
 }

+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java

@@ -161,4 +161,16 @@ public class ElasticInferenceServiceAuthorizationModel {
     public int hashCode() {
         return Objects.hash(taskTypeToModels, authorizedTaskTypes, authorizedModelIds);
     }
+
+    @Override
+    public String toString() {
+        return "{"
+            + "taskTypeToModels="
+            + taskTypeToModels
+            + ", authorizedTaskTypes="
+            + authorizedTaskTypes
+            + ", authorizedModelIds="
+            + authorizedModelIds
+            + '}';
+    }
 }

+ 14 - 13
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandler.java

@@ -9,7 +9,8 @@ package org.elasticsearch.xpack.inference.services.elastic.authorization;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.elasticsearch.ElasticsearchWrapperException;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.core.Nullable;
@@ -86,25 +87,25 @@ public class ElasticInferenceServiceAuthorizationRequestHandler {
 
             ActionListener<InferenceServiceResults> newListener = ActionListener.wrap(results -> {
                 if (results instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
+                    logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
                     listener.onResponse(ElasticInferenceServiceAuthorizationModel.of(authResponseEntity));
                 } else {
-                    logger.warn(
-                        Strings.format(
-                            FAILED_TO_RETRIEVE_MESSAGE + " Received an invalid response type: %s",
-                            results.getClass().getSimpleName()
-                        )
+                    var errorMessage = Strings.format(
+                        "%s Received an invalid response type from the Elastic Inference Service: %s",
+                        FAILED_TO_RETRIEVE_MESSAGE,
+                        results.getClass().getSimpleName()
                     );
-                    listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
+
+                    logger.warn(errorMessage);
+                    listener.onFailure(new ElasticsearchException(errorMessage));
                 }
                 requestCompleteLatch.countDown();
             }, e -> {
-                Throwable exception = e;
-                if (e instanceof ElasticsearchWrapperException wrapperException) {
-                    exception = wrapperException.getCause();
-                }
+                // unwrap because it's likely a retry exception
+                var exception = ExceptionsHelper.unwrapCause(e);
 
-                logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception));
-                listener.onResponse(ElasticInferenceServiceAuthorizationModel.newDisabledService());
+                logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
+                listener.onFailure(e);
                 requestCompleteLatch.countDown();
             });
 

+ 16 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.services.elastic.response;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -14,6 +15,8 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContent;
@@ -39,6 +42,9 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg
 public class ElasticInferenceServiceAuthorizationResponseEntity implements InferenceServiceResults {
 
     public static final String NAME = "elastic_inference_service_auth_results";
+
+    private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationResponseEntity.class);
+    private static final String AUTH_FIELD_NAME = "authorized_models";
     private static final Map<String, TaskType> ELASTIC_INFERENCE_SERVICE_TASK_TYPE_MAPPING = Map.of(
         "embed/text/sparse",
         TaskType.SPARSE_EMBEDDING,
@@ -107,6 +113,11 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
 
             return builder;
         }
+
+        @Override
+        public String toString() {
+            return Strings.format("{modelName='%s', taskTypes='%s'}", modelName, taskTypes);
+        }
     }
 
     private final List<AuthorizedModel> authorizedModels;
@@ -138,6 +149,11 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
         return authorizedModels;
     }
 
+    @Override
+    public String toString() {
+        return authorizedModels.stream().map(AuthorizedModel::toString).collect(Collectors.joining(", "));
+    }
+
     @Override
     public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
         throw new UnsupportedOperationException();

+ 76 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

@@ -67,6 +67,78 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNo
         modelRegistry = getInstanceFromNode(ModelRegistry.class);
     }
 
+    public void testSecondAuthResultRevokesAuthorization() throws Exception {
+        var callbackCount = new AtomicInteger(0);
+        // we're only interested in two authorization calls which is why I'm using a value of 2 here
+        var latch = new CountDownLatch(2);
+        final AtomicReference<ElasticInferenceServiceAuthorizationHandler> handlerRef = new AtomicReference<>();
+
+        Runnable callback = () -> {
+            // the first authorization response contains a streaming task so we're expecting to support streaming here
+            if (callbackCount.incrementAndGet() == 1) {
+                assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
+            }
+            latch.countDown();
+
+            // we only want to run the tasks twice, so advance the time on the queue
+            // which flags the scheduled authorization request to be ready to run
+            if (callbackCount.get() == 1) {
+                taskQueue.advanceTime();
+            } else {
+                try {
+                    handlerRef.get().close();
+                } catch (IOException e) {
+                    // ignore
+                }
+            }
+        };
+
+        var requestHandler = mockAuthorizationRequestHandler(
+            ElasticInferenceServiceAuthorizationModel.of(
+                new ElasticInferenceServiceAuthorizationResponseEntity(
+                    List.of(
+                        new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
+                            "rainbow-sprinkles",
+                            EnumSet.of(TaskType.CHAT_COMPLETION)
+                        )
+                    )
+                )
+            ),
+            ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of()))
+        );
+
+        handlerRef.set(
+            new ElasticInferenceServiceAuthorizationHandler(
+                createWithEmptySettings(taskQueue.getThreadPool()),
+                modelRegistry,
+                requestHandler,
+                initDefaultEndpoints(),
+                EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION),
+                null,
+                mock(Sender.class),
+                ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true),
+                callback
+            )
+        );
+
+        var handler = handlerRef.get();
+        handler.init();
+        taskQueue.runAllRunnableTasks();
+        latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS);
+
+        // this should be after we've received both authorization responses, the second response will revoke authorization
+
+        assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
+        assertThat(handler.defaultConfigIds(), is(List.of()));
+        assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class)));
+
+        PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
+        handler.defaultConfigs(listener);
+
+        var configs = listener.actionGet();
+        assertThat(configs.size(), is(0));
+    }
+
     public void testSendsAnAuthorizationRequestTwice() throws Exception {
         var callbackCount = new AtomicInteger(0);
         // we're only interested in two authorization calls which is why I'm using a value of 2 here
@@ -104,6 +176,10 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNo
             ElasticInferenceServiceAuthorizationModel.of(
                 new ElasticInferenceServiceAuthorizationResponseEntity(
                     List.of(
+                        new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
+                            "abc",
+                            EnumSet.of(TaskType.SPARSE_EMBEDDING)
+                        ),
                         new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
                             "rainbow-sprinkles",
                             EnumSet.of(TaskType.CHAT_COMPLETION)

+ 17 - 27
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationRequestHandlerTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.services.elastic.authorization;
 
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
@@ -18,6 +19,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockResponse;
 import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentParseException;
 import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -38,13 +40,14 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender.MAX_RETIES;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends ESTestCase {
@@ -135,22 +138,17 @@ public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends EST
             PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
             authHandler.getAuthorization(listener, sender);
 
-            var authResponse = listener.actionGet(TIMEOUT);
-            assertTrue(authResponse.getAuthorizedTaskTypes().isEmpty());
-            assertTrue(authResponse.getAuthorizedModelIds().isEmpty());
-            assertFalse(authResponse.isAuthorized());
+            var exception = expectThrows(XContentParseException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(exception.getMessage(), containsString("failed to parse field [models]"));
 
-            var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
-            verify(logger).warn(loggerArgsCaptor.capture());
-            var message = loggerArgsCaptor.getValue();
-            assertThat(
-                message,
-                is(
-                    "Failed to retrieve the authorization information from the Elastic Inference Service."
-                        + " Encountered an exception: org.elasticsearch.xcontent.XContentParseException: [4:28] "
-                        + "[ElasticInferenceServiceAuthorizationResponseEntity] failed to parse field [models]"
-                )
-            );
+            var stringCaptor = ArgumentCaptor.forClass(String.class);
+            var exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
+            verify(logger).warn(stringCaptor.capture(), exceptionCaptor.capture());
+            var message = stringCaptor.getValue();
+            assertThat(message, containsString("failed to parse field [models]"));
+
+            var capturedException = exceptionCaptor.getValue();
+            assertThat(capturedException, instanceOf(XContentParseException.class));
         }
     }
 
@@ -196,7 +194,6 @@ public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends EST
 
             var message = loggerArgsCaptor.getValue();
             assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
-            verifyNoMoreInteractions(logger);
         }
     }
 
@@ -230,7 +227,6 @@ public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends EST
 
             var message = loggerArgsCaptor.getValue();
             assertThat(message, is("Retrieving authorization information from the Elastic Inference Service."));
-            verifyNoMoreInteractions(logger);
         }
     }
 
@@ -252,20 +248,14 @@ public class ElasticInferenceServiceAuthorizationRequestHandlerTests extends EST
             PlainActionFuture<ElasticInferenceServiceAuthorizationModel> listener = new PlainActionFuture<>();
 
             authHandler.getAuthorization(listener, sender);
-            var result = listener.actionGet(TIMEOUT);
+            var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
-            assertThat(result, is(ElasticInferenceServiceAuthorizationModel.newDisabledService()));
+            assertThat(exception.getMessage(), containsString("Received an invalid response type from the Elastic Inference Service"));
 
             var loggerArgsCaptor = ArgumentCaptor.forClass(String.class);
             verify(logger).warn(loggerArgsCaptor.capture());
             var message = loggerArgsCaptor.getValue();
-            assertThat(
-                message,
-                is(
-                    "Failed to retrieve the authorization information from the Elastic Inference Service."
-                        + " Received an invalid response type: ChatCompletionResults"
-                )
-            );
+            assertThat(message, containsString("Failed to retrieve the authorization information from the Elastic Inference Service."));
         }
 
     }