Browse Source

Refactor CachingUsernamePassword realm (#32646)

Refactors the logic of authentication and lookup caching in
`CachingUsernamePasswordRealm`. Nothing changed about
the single-inflight-request or positive caching.
Albert Zaharovits 7 years ago
parent
commit
c567ec4a0f

+ 109 - 114
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java

@@ -5,11 +5,9 @@
  */
 package org.elasticsearch.xpack.security.authc.support;
 
-import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.cache.Cache;
 import org.elasticsearch.common.cache.CacheBuilder;
-import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
@@ -30,7 +28,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 
 public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm implements CachingRealm {
 
-    private final Cache<String, ListenableFuture<Tuple<AuthenticationResult, UserWithHash>>> cache;
+    private final Cache<String, ListenableFuture<UserWithHash>> cache;
     private final ThreadPool threadPool;
     final Hasher cacheHasher;
 
@@ -38,9 +36,9 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
         super(type, config);
         cacheHasher = Hasher.resolve(CachingUsernamePasswordRealmSettings.CACHE_HASH_ALGO_SETTING.get(config.settings()));
         this.threadPool = threadPool;
-        TimeValue ttl = CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING.get(config.settings());
+        final TimeValue ttl = CachingUsernamePasswordRealmSettings.CACHE_TTL_SETTING.get(config.settings());
         if (ttl.getNanos() > 0) {
-            cache = CacheBuilder.<String, ListenableFuture<Tuple<AuthenticationResult, UserWithHash>>>builder()
+            cache = CacheBuilder.<String, ListenableFuture<UserWithHash>>builder()
                     .setExpireAfterWrite(ttl)
                     .setMaximumWeight(CachingUsernamePasswordRealmSettings.CACHE_MAX_USERS_SETTING.get(config.settings()))
                     .build();
@@ -49,6 +47,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
         }
     }
 
+    @Override
     public final void expire(String username) {
         if (cache != null) {
             logger.trace("invalidating cache for user [{}] in realm [{}]", username, name());
@@ -56,6 +55,7 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
         }
     }
 
+    @Override
     public final void expireAll() {
         if (cache != null) {
             logger.trace("invalidating cache for all users in realm [{}]", name());
@@ -72,108 +72,84 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
      */
     @Override
     public final void authenticate(AuthenticationToken authToken, ActionListener<AuthenticationResult> listener) {
-        UsernamePasswordToken token = (UsernamePasswordToken) authToken;
+        final UsernamePasswordToken token = (UsernamePasswordToken) authToken;
         try {
             if (cache == null) {
                 doAuthenticate(token, listener);
             } else {
                 authenticateWithCache(token, listener);
             }
-        } catch (Exception e) {
+        } catch (final Exception e) {
             // each realm should handle exceptions, if we get one here it should be considered fatal
             listener.onFailure(e);
         }
     }
 
+    /**
+     * This validates the {@code token} while making sure there is only one inflight
+     * request to the authentication source. Only successful responses are cached
+     * and any subsequent requests, bearing the <b>same</b> password, will succeed
+     * without reaching to the authentication source. A different password in a
+     * subsequent request, however, will clear the cache and <b>try</b> to reach to
+     * the authentication source.
+     *
+     * @param token The authentication token
+     * @param listener to be called at completion
+     */
     private void authenticateWithCache(UsernamePasswordToken token, ActionListener<AuthenticationResult> listener) {
         try {
-            final SetOnce<User> authenticatedUser = new SetOnce<>();
-            final AtomicBoolean createdAndStartedFuture = new AtomicBoolean(false);
-            final ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future = cache.computeIfAbsent(token.principal(), k -> {
-                final ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> created = new ListenableFuture<>();
-                if (createdAndStartedFuture.compareAndSet(false, true) == false) {
-                    throw new IllegalStateException("something else already started this. how?");
-                }
-                return created;
+            final AtomicBoolean authenticationInCache = new AtomicBoolean(true);
+            final ListenableFuture<UserWithHash> listenableCacheEntry = cache.computeIfAbsent(token.principal(), k -> {
+                authenticationInCache.set(false);
+                return new ListenableFuture<>();
             });
-
-            if (createdAndStartedFuture.get()) {
-                doAuthenticate(token, ActionListener.wrap(result -> {
-                    if (result.isAuthenticated()) {
-                        final User user = result.getUser();
-                        authenticatedUser.set(user);
-                        final UserWithHash userWithHash = new UserWithHash(user, token.credentials(), cacheHasher);
-                        future.onResponse(new Tuple<>(result, userWithHash));
-                    } else {
-                        future.onResponse(new Tuple<>(result, null));
-                    }
-                }, future::onFailure));
-            }
-
-            future.addListener(ActionListener.wrap(tuple -> {
-                if (tuple != null) {
-                    final UserWithHash userWithHash = tuple.v2();
-                    final boolean performedAuthentication = createdAndStartedFuture.get() && userWithHash != null &&
-                        tuple.v2().user == authenticatedUser.get();
-                    handleResult(future, createdAndStartedFuture.get(), performedAuthentication, token, tuple, listener);
-                } else {
-                    handleFailure(future, createdAndStartedFuture.get(), token, new IllegalStateException("unknown error authenticating"),
-                        listener);
-                }
-            }, e -> handleFailure(future, createdAndStartedFuture.get(), token, e, listener)),
-                threadPool.executor(ThreadPool.Names.GENERIC));
-        } catch (ExecutionException e) {
-            listener.onResponse(AuthenticationResult.unsuccessful("", e));
-        }
-    }
-
-    private void handleResult(ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future, boolean createdAndStartedFuture,
-                              boolean performedAuthentication, UsernamePasswordToken token,
-                              Tuple<AuthenticationResult, UserWithHash> result, ActionListener<AuthenticationResult> listener) {
-        final AuthenticationResult authResult = result.v1();
-        if (authResult == null) {
-            // this was from a lookup; clear and redo
-            cache.invalidate(token.principal(), future);
-            authenticateWithCache(token, listener);
-        } else if (authResult.isAuthenticated()) {
-            if (performedAuthentication) {
-                listener.onResponse(authResult);
-            } else {
-                UserWithHash userWithHash = result.v2();
-                if (userWithHash.verify(token.credentials())) {
-                    if (userWithHash.user.enabled()) {
-                        User user = userWithHash.user;
-                        logger.debug("realm [{}] authenticated user [{}], with roles [{}]",
-                            name(), token.principal(), user.roles());
+            if (authenticationInCache.get()) {
+                // there is a cached or an inflight authenticate request
+                listenableCacheEntry.addListener(ActionListener.wrap(authenticatedUserWithHash -> {
+                    if (authenticatedUserWithHash != null && authenticatedUserWithHash.verify(token.credentials())) {
+                        // cached credential hash matches the credential hash for this forestalled request
+                        final User user = authenticatedUserWithHash.user;
+                        logger.debug("realm [{}] authenticated user [{}], with roles [{}], from cache", name(), token.principal(),
+                                user.roles());
                         listener.onResponse(AuthenticationResult.success(user));
                     } else {
-                        // re-auth to see if user has been enabled
-                        cache.invalidate(token.principal(), future);
+                        // The inflight request has failed or its credential hash does not match the
+                        // hash of the credential for this forestalled request.
+                        // clear cache and try to reach the authentication source again because password
+                        // might have changed there and the local cached hash got stale
+                        cache.invalidate(token.principal(), listenableCacheEntry);
                         authenticateWithCache(token, listener);
                     }
-                } else {
-                    // could be a password change?
-                    cache.invalidate(token.principal(), future);
+                }, e -> {
+                    // the inflight request failed, so try again, but first (always) make sure cache
+                    // is cleared of the failed authentication
+                    cache.invalidate(token.principal(), listenableCacheEntry);
                     authenticateWithCache(token, listener);
-                }
-            }
-        } else {
-            cache.invalidate(token.principal(), future);
-            if (createdAndStartedFuture) {
-                listener.onResponse(authResult);
+                }), threadPool.executor(ThreadPool.Names.GENERIC));
             } else {
-                authenticateWithCache(token, listener);
+                // attempt authentication against the authentication source
+                doAuthenticate(token, ActionListener.wrap(authResult -> {
+                    if (authResult.isAuthenticated() && authResult.getUser().enabled()) {
+                        // compute the credential hash of this successful authentication request
+                        final UserWithHash userWithHash = new UserWithHash(authResult.getUser(), token.credentials(), cacheHasher);
+                        // notify any forestalled request listeners; they will not reach to the
+                        // authentication request and instead will use this hash for comparison
+                        listenableCacheEntry.onResponse(userWithHash);
+                    } else {
+                        // notify any forestalled request listeners; they will retry the request
+                        listenableCacheEntry.onResponse(null);
+                    }
+                    // notify the listener of the inflight authentication request; this request is not retried
+                    listener.onResponse(authResult);
+                }, e -> {
+                    // notify any staved off listeners; they will retry the request
+                    listenableCacheEntry.onFailure(e);
+                    // notify the listener of the inflight authentication request; this request is not retried
+                    listener.onFailure(e);
+                }));
             }
-        }
-    }
-
-    private void handleFailure(ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future, boolean createdAndStarted,
-                               UsernamePasswordToken token, Exception e, ActionListener<AuthenticationResult> listener) {
-        cache.invalidate(token.principal(), future);
-        if (createdAndStarted) {
+        } catch (final ExecutionException e) {
             listener.onFailure(e);
-        } else {
-            authenticateWithCache(token, listener);
         }
     }
 
@@ -193,38 +169,57 @@ public abstract class CachingUsernamePasswordRealm extends UsernamePasswordRealm
 
     @Override
     public final void lookupUser(String username, ActionListener<User> listener) {
-        if (cache != null) {
-            try {
-                ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> future = cache.computeIfAbsent(username, key -> {
-                    ListenableFuture<Tuple<AuthenticationResult, UserWithHash>> created = new ListenableFuture<>();
-                    doLookupUser(username, ActionListener.wrap(user -> {
-                        if (user != null) {
-                            UserWithHash userWithHash = new UserWithHash(user, null, null);
-                            created.onResponse(new Tuple<>(null, userWithHash));
-                        } else {
-                            created.onResponse(new Tuple<>(null, null));
-                        }
-                    }, created::onFailure));
-                    return created;
-                });
-
-                future.addListener(ActionListener.wrap(tuple -> {
-                    if (tuple != null) {
-                        if (tuple.v2() == null) {
-                            cache.invalidate(username, future);
-                            listener.onResponse(null);
-                        } else {
-                            listener.onResponse(tuple.v2().user);
-                        }
+        try {
+            if (cache == null) {
+                doLookupUser(username, listener);
+            } else {
+                lookupWithCache(username, listener);
+            }
+        } catch (final Exception e) {
+            // each realm should handle exceptions, if we get one here it should be
+            // considered fatal
+            listener.onFailure(e);
+        }
+    }
+
+    private void lookupWithCache(String username, ActionListener<User> listener) {
+        try {
+            final AtomicBoolean lookupInCache = new AtomicBoolean(true);
+            final ListenableFuture<UserWithHash> listenableCacheEntry = cache.computeIfAbsent(username, key -> {
+                lookupInCache.set(false);
+                return new ListenableFuture<>();
+            });
+            if (false == lookupInCache.get()) {
+                // attempt lookup against the user directory
+                doLookupUser(username, ActionListener.wrap(user -> {
+                    if (user != null) {
+                        // user found
+                        final UserWithHash userWithHash = new UserWithHash(user, null, null);
+                        // notify forestalled request listeners
+                        listenableCacheEntry.onResponse(userWithHash);
                     } else {
-                        listener.onResponse(null);
+                        // user not found, invalidate cache so that subsequent requests are forwarded to
+                        // the user directory
+                        cache.invalidate(username, listenableCacheEntry);
+                        // notify forestalled request listeners
+                        listenableCacheEntry.onResponse(null);
                     }
-                }, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
-            } catch (ExecutionException e) {
-                listener.onFailure(e);
+                }, e -> {
+                    // the next request should be forwarded, not halted by a failed lookup attempt
+                    cache.invalidate(username, listenableCacheEntry);
+                    // notify forestalled listeners
+                    listenableCacheEntry.onFailure(e);
+                }));
             }
-        } else {
-            doLookupUser(username, listener);
+            listenableCacheEntry.addListener(ActionListener.wrap(userWithHash -> {
+                if (userWithHash != null) {
+                    listener.onResponse(userWithHash.user);
+                } else {
+                    listener.onResponse(null);
+                }
+            }, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
+        } catch (final ExecutionException e) {
+            listener.onFailure(e);
         }
     }