Prechádzať zdrojové kódy

Add infrastructure to mark contexts as system contexts (#23830)

Today we have no way to mark an execution as internal. This commit adds
a simple thread context header that allows executing code in a system context.
This allows intercepting code can make better decisions down the road when
it gets to authentication.
Simon Willnauer 8 rokov pred
rodič
commit
5badf68bd9

+ 14 - 3
core/src/main/java/org/elasticsearch/action/search/RemoteClusterConnection.java

@@ -35,6 +35,7 @@ import org.elasticsearch.common.component.AbstractComponent;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.CancellableThreads;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.ConnectionProfile;
@@ -59,7 +60,6 @@ import java.util.Set;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.Semaphore;
@@ -373,10 +373,19 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo
                             // here we pass on the connection since we can only close it once the sendRequest returns otherwise
                             // due to the async nature (it will return before it's actually sent) this can cause the request to fail
                             // due to an already closed connection.
-                            transportService.sendRequest(connection,
-                                ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY,
+                            ThreadPool threadPool = transportService.getThreadPool();
+                            ThreadContext threadContext = threadPool.getThreadContext();
+                            TransportService.ContextRestoreResponseHandler<ClusterStateResponse> responseHandler = new TransportService
+                                .ContextRestoreResponseHandler<>(threadContext.newRestorableContext(false),
                                 new SniffClusterStateResponseHandler(transportService, connection, listener, seedNodes,
                                     cancellableThreads));
+                            try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
+                                // we stash any context here since this is an internal execution and should not leak any
+                                // existing context information.
+                                threadContext.markAsSystemContext();
+                                transportService.sendRequest(connection, ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY,
+                                    responseHandler);
+                            }
                             success = true;
                         } finally {
                             if (success == false) {
@@ -445,6 +454,7 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo
 
             @Override
             public void handleResponse(ClusterStateResponse response) {
+                assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
                 try {
                     try (Closeable theConnection = connection) { // the connection is unused - see comment in #collectRemoteNodes
                         // we have to close this connection before we notify listeners - this is mainly needed for test correctness
@@ -483,6 +493,7 @@ final class RemoteClusterConnection extends AbstractComponent implements Transpo
 
             @Override
             public void handleException(TransportException exp) {
+                assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
                 logger.warn((Supplier<?>)
                     () -> new ParameterizedMessage("fetching nodes from external cluster {} failed", clusterAlias),
                     exp);

+ 33 - 8
core/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java

@@ -25,7 +25,6 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting.Property;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.index.store.Store;
 
 import java.io.Closeable;
 import java.io.IOException;
@@ -75,6 +74,7 @@ public final class ThreadContext implements Closeable, Writeable {
     private static final ThreadContextStruct DEFAULT_CONTEXT = new ThreadContextStruct();
     private final Map<String, String> defaultHeader;
     private final ContextThreadLocal threadLocal;
+    private boolean isSystemContext;
 
     /**
      * Creates a new ThreadContext instance
@@ -317,6 +317,21 @@ public final class ThreadContext implements Closeable, Writeable {
         return threadLocal.get() == DEFAULT_CONTEXT;
     }
 
+    /**
+     * Marks this thread context as an internal system context. This signals that actions in this context are issued
+     * by the system itself rather than by a user action.
+     */
+    public void markAsSystemContext() {
+        threadLocal.set(threadLocal.get().setSystemContext());
+    }
+
+    /**
+     * Returns <code>true</code> iff this context is a system context
+     */
+    public boolean isSystemContext() {
+        return threadLocal.get().isSystemContext;
+    }
+
     /**
      * Returns <code>true</code> if the context is closed, otherwise <code>true</code>
      */
@@ -338,6 +353,7 @@ public final class ThreadContext implements Closeable, Writeable {
         private final Map<String, String> requestHeaders;
         private final Map<String, Object> transientHeaders;
         private final Map<String, List<String>> responseHeaders;
+        private final boolean isSystemContext;
 
         private ThreadContextStruct(StreamInput in) throws IOException {
             final int numRequest = in.readVInt();
@@ -349,27 +365,36 @@ public final class ThreadContext implements Closeable, Writeable {
             this.requestHeaders = requestHeaders;
             this.responseHeaders = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
             this.transientHeaders = Collections.emptyMap();
+            isSystemContext = false; // we never serialize this it's a transient flag
+        }
+
+        private ThreadContextStruct setSystemContext() {
+            if (isSystemContext) {
+                return this;
+            }
+            return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, true);
         }
 
         private ThreadContextStruct(Map<String, String> requestHeaders,
                                     Map<String, List<String>> responseHeaders,
-                                    Map<String, Object> transientHeaders) {
+                                    Map<String, Object> transientHeaders, boolean isSystemContext) {
             this.requestHeaders = requestHeaders;
             this.responseHeaders = responseHeaders;
             this.transientHeaders = transientHeaders;
+            this.isSystemContext = isSystemContext;
         }
 
         /**
          * This represents the default context and it should only ever be called by {@link #DEFAULT_CONTEXT}.
          */
         private ThreadContextStruct() {
-            this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
+            this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false);
         }
 
         private ThreadContextStruct putRequest(String key, String value) {
             Map<String, String> newRequestHeaders = new HashMap<>(this.requestHeaders);
             putSingleHeader(key, value, newRequestHeaders);
-            return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders);
+            return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext);
         }
 
         private void putSingleHeader(String key, String value, Map<String, String> newHeaders) {
@@ -387,7 +412,7 @@ public final class ThreadContext implements Closeable, Writeable {
                     putSingleHeader(entry.getKey(), entry.getValue(), newHeaders);
                 }
                 newHeaders.putAll(this.requestHeaders);
-                return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders);
+                return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders, isSystemContext);
             }
         }
 
@@ -408,7 +433,7 @@ public final class ThreadContext implements Closeable, Writeable {
                     newResponseHeaders.put(key, entry.getValue());
                 }
             }
-            return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders);
+            return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext);
         }
 
         private ThreadContextStruct putResponse(final String key, final String value, final Function<String, String> uniqueValue) {
@@ -432,7 +457,7 @@ public final class ThreadContext implements Closeable, Writeable {
                 newResponseHeaders.put(key, Collections.singletonList(value));
             }
 
-            return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders);
+            return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext);
         }
 
         private ThreadContextStruct putTransient(String key, Object value) {
@@ -440,7 +465,7 @@ public final class ThreadContext implements Closeable, Writeable {
             if (newTransient.putIfAbsent(key, value) != null) {
                 throw new IllegalArgumentException("value for key [" + key + "] already present");
             }
-            return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient);
+            return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext);
         }
 
         boolean isEmpty() {

+ 12 - 1
core/src/main/java/org/elasticsearch/transport/TransportActionProxy.java

@@ -134,11 +134,12 @@ public final class TransportActionProxy {
             true, false, new ProxyRequestHandler<>(service, action, responseSupplier));
     }
 
+    private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/";
     /**
      * Returns the corresponding proxy action for the given action
      */
     public static String getProxyAction(String action) {
-        return "internal:transport/proxy/" + action;
+        return  PROXY_ACTION_PREFIX + action;
     }
 
     /**
@@ -147,4 +148,14 @@ public final class TransportActionProxy {
     public static TransportRequest wrapRequest(DiscoveryNode node, TransportRequest request) {
         return new ProxyRequest<>(request, node);
     }
+
+    /**
+     * Unwraps a proxy request and returns the original request
+     */
+    public static TransportRequest unwrapRequest(TransportRequest request) {
+        if (request instanceof ProxyRequest) {
+            return ((ProxyRequest)request).wrapped;
+        }
+        return request;
+    }
 }

+ 19 - 0
core/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java

@@ -430,11 +430,16 @@ public class ThreadContextTests extends ESTestCase {
             // create a abstract runnable, add headers and transient objects and verify in the methods
             try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
                 threadContext.putHeader("foo", "bar");
+                boolean systemContext = randomBoolean();
+                if (systemContext) {
+                    threadContext.markAsSystemContext();
+                }
                 threadContext.putTransient("foo", "bar_transient");
                 withContext = threadContext.preserveContext(new AbstractRunnable() {
 
                     @Override
                     public void onAfter() {
+                        assertEquals(systemContext, threadContext.isSystemContext());
                         assertEquals("bar", threadContext.getHeader("foo"));
                         assertEquals("bar_transient", threadContext.getTransient("foo"));
                         assertNotNull(threadContext.getTransient("failure"));
@@ -445,6 +450,7 @@ public class ThreadContextTests extends ESTestCase {
 
                     @Override
                     public void onFailure(Exception e) {
+                        assertEquals(systemContext, threadContext.isSystemContext());
                         assertEquals("exception from doRun", e.getMessage());
                         assertEquals("bar", threadContext.getHeader("foo"));
                         assertEquals("bar_transient", threadContext.getTransient("foo"));
@@ -454,6 +460,7 @@ public class ThreadContextTests extends ESTestCase {
 
                     @Override
                     protected void doRun() throws Exception {
+                        assertEquals(systemContext, threadContext.isSystemContext());
                         assertEquals("bar", threadContext.getHeader("foo"));
                         assertEquals("bar_transient", threadContext.getTransient("foo"));
                         assertFalse(threadContext.isDefaultContext());
@@ -594,6 +601,18 @@ public class ThreadContextTests extends ESTestCase {
         }
     }
 
+    public void testMarkAsSystemContext() throws IOException {
+        try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
+            assertFalse(threadContext.isSystemContext());
+            try(ThreadContext.StoredContext context = threadContext.stashContext()){
+                assertFalse(threadContext.isSystemContext());
+                threadContext.markAsSystemContext();
+                assertTrue(threadContext.isSystemContext());
+            }
+            assertFalse(threadContext.isSystemContext());
+        }
+    }
+
     /**
      * Sometimes wraps a Runnable in an AbstractRunnable.
      */

+ 12 - 0
core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java

@@ -246,4 +246,16 @@ public class TransportActionProxyTests extends ESTestCase {
         }
     }
 
+    public void testGetAction() {
+        String action = "foo/bar";
+        String proxyAction = TransportActionProxy.getProxyAction(action);
+        assertTrue(proxyAction.endsWith(action));
+        assertEquals("internal:transport/proxy/foo/bar", proxyAction);
+    }
+
+    public void testUnwrap() {
+        TransportRequest transportRequest = TransportActionProxy.wrapRequest(nodeA, TransportService.HandshakeRequest.INSTANCE);
+        assertTrue(transportRequest instanceof TransportActionProxy.ProxyRequest);
+        assertSame(TransportService.HandshakeRequest.INSTANCE, TransportActionProxy.unwrapRequest(transportRequest));
+    }
 }