Browse Source

Fix Token Service retry mechanism (#39639)

Fixes several errors of the token retry logic:

* not checking for backoff.hasNext() before calling backoff.next()
* checking for backoff.hasNext() without calling backoff.next()
* not preserving the context on the retry
* calling scheduleWithFixedDelay instead of schedule
Albert Zaharovits 6 years ago
parent
commit
e3981b7576

+ 91 - 96
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java

@@ -200,7 +200,6 @@ public final class TokenService {
                         SecurityIndexManager securityIndex, ClusterService clusterService) throws GeneralSecurityException {
         byte[] saltArr = new byte[SALT_BYTES];
         secureRandom.nextBytes(saltArr);
-
         final SecureString tokenPassphrase = generateTokenKey();
         this.settings = settings;
         this.clock = clock.withZone(ZoneOffset.UTC);
@@ -681,19 +680,19 @@ public final class TokenService {
                         if (retryTokenDocIds.isEmpty() == false) {
                             if (backoff.hasNext()) {
                                 logger.debug("failed to invalidate [{}] tokens out of [{}], retrying to invalidate these too",
-                                    retryTokenDocIds.size(), tokenIds.size());
-                                TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated,
-                                    failedRequestResponses);
-                                client.threadPool().schedule(
-                                    () -> indexInvalidation(retryTokenDocIds, listener, backoff, srcPrefix, incompleteResult),
-                                    backoff.next(), GENERIC);
+                                        retryTokenDocIds.size(), tokenIds.size());
+                                final TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated,
+                                        previouslyInvalidated, failedRequestResponses);
+                                final Runnable retryWithContextRunnable = client.threadPool().getThreadContext().preserveContext(
+                                        () -> indexInvalidation(retryTokenDocIds, listener, backoff, srcPrefix, incompleteResult));
+                                client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
                             } else {
-                                logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries",
-                                    retryTokenDocIds.size(), tokenIds.size());
+                                logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries", retryTokenDocIds.size(),
+                                        tokenIds.size());
                             }
                         } else {
-                            TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated,
-                                failedRequestResponses);
+                            final TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated,
+                                    failedRequestResponses);
                             listener.onResponse(result);
                         }
                     }, e -> {
@@ -701,8 +700,9 @@ public final class TokenService {
                         traceLog("invalidate tokens", cause);
                         if (isShardNotAvailableException(cause) && backoff.hasNext()) {
                             logger.debug("failed to invalidate tokens, retrying ");
-                            client.threadPool().schedule(
-                                () -> indexInvalidation(tokenIds, listener, backoff, srcPrefix, previousResult), backoff.next(), GENERIC);
+                            final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
+                                    .preserveContext(() -> indexInvalidation(tokenIds, listener, backoff, srcPrefix, previousResult));
+                            client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
                         } else {
                             listener.onFailure(e);
                         }
@@ -734,34 +734,39 @@ public final class TokenService {
      */
     private void findTokenFromRefreshToken(String refreshToken, ActionListener<SearchResponse> listener,
                                            Iterator<TimeValue> backoff) {
-        SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
-            .setQuery(QueryBuilders.boolQuery()
-                .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE))
-                .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken)))
-            .seqNoAndPrimaryTerm(true)
-            .request();
-
+        final Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("find token by refresh token", refreshToken, ex));
+        final Consumer<Exception> maybeRetryOnFailure = ex -> {
+            if (backoff.hasNext()) {
+                final TimeValue backofTimeValue = backoff.next();
+                logger.debug("retrying after [" + backofTimeValue + "] back off");
+                final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
+                        .preserveContext(() -> findTokenFromRefreshToken(refreshToken, listener, backoff));
+                client.threadPool().schedule(retryWithContextRunnable, backofTimeValue, GENERIC);
+            } else {
+                logger.warn("failed to find token from refresh token after all retries");
+                onFailure.accept(ex);
+            }
+        };
         final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze();
         if (frozenSecurityIndex.indexExists() == false) {
             logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken);
             listener.onFailure(invalidGrantException("could not refresh the requested token"));
         } else if (frozenSecurityIndex.isAvailable() == false) {
             logger.debug("security index is not available to find token from refresh token, retrying");
-            client.threadPool().scheduleWithFixedDelay(
-                () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC);
+            maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token"));
         } else {
-            Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex));
+            final SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
+                    .setQuery(QueryBuilders.boolQuery()
+                            .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE))
+                            .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken)))
+                    .seqNoAndPrimaryTerm(true)
+                    .request();
             securityIndex.checkIndexVersionThenExecute(listener::onFailure, () ->
                 executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request,
                     ActionListener.<SearchResponse>wrap(searchResponse -> {
                         if (searchResponse.isTimedOut()) {
-                            if (backoff.hasNext()) {
-                                client.threadPool().scheduleWithFixedDelay(
-                                    () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC);
-                            } else {
-                                logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken);
-                                onFailure.accept(invalidGrantException("could not refresh the requested token"));
-                            }
+                            logger.debug("find token from refresh token response timed out, retrying");
+                            maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token"));
                         } else if (searchResponse.getHits().getHits().length < 1) {
                             logger.warn("could not find token document with refresh_token [{}]", refreshToken);
                             onFailure.accept(invalidGrantException("could not refresh the requested token"));
@@ -772,14 +777,8 @@ public final class TokenService {
                         }
                     }, e -> {
                         if (isShardNotAvailableException(e)) {
-                            if (backoff.hasNext()) {
-                                logger.debug("failed to find token for refresh token [{}], retrying", refreshToken);
-                                client.threadPool().scheduleWithFixedDelay(
-                                    () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC);
-                            } else {
-                                logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken);
-                                onFailure.accept(invalidGrantException("could not refresh the requested token"));
-                            }
+                            logger.debug("find token from refresh token request failed because of unavailable shards, retrying");
+                            maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token"));
                         } else {
                             onFailure.accept(e);
                         }
@@ -804,7 +803,7 @@ public final class TokenService {
     private void innerRefresh(String tokenDocId, Map<String, Object> source, long seqNo, long primaryTerm, Authentication clientAuth,
                               ActionListener<Tuple<UserToken, String>> listener, Iterator<TimeValue> backoff, Instant refreshRequested) {
         logger.debug("Attempting to refresh token [{}]", tokenDocId);
-        Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex));
+        final Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex));
         final Optional<ElasticsearchSecurityException> invalidSource = checkTokenDocForRefresh(source, clientAuth);
         if (invalidSource.isPresent()) {
             onFailure.accept(invalidSource.get());
@@ -815,6 +814,19 @@ public final class TokenService {
                 logger.debug("Token document [{}] was recently refreshed, attempting to reuse [{}] for returning an " +
                     "access token and refresh token", tokenDocId, supersedingTokenDocId);
                 final ActionListener<GetResponse> getSupersedingListener = new ActionListener<GetResponse>() {
+                    private final Consumer<Exception> maybeRetryOnFailure = ex -> {
+                        if (backoff.hasNext()) {
+                            final TimeValue backofTimeValue = backoff.next();
+                            logger.debug("retrying after [" + backofTimeValue + "] back off");
+                            final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
+                                    .preserveContext(() -> getTokenDocAsync(supersedingTokenDocId, this));
+                            client.threadPool().schedule(retryWithContextRunnable, backofTimeValue, GENERIC);
+                        } else {
+                            logger.warn("back off retries exhausted");
+                            onFailure.accept(ex);
+                        }
+                    };
+
                     @Override
                     public void onResponse(GetResponse response) {
                         if (response.isExists()) {
@@ -826,30 +838,20 @@ public final class TokenService {
                                 (Map<String, Object>) supersedingTokenSource.get("refresh_token");
                             final String supersedingRefreshTokenValue = (String) supersedingRefreshTokenSrc.get("token");
                             reIssueTokens(supersedingUserTokenSource, supersedingRefreshTokenValue, listener);
-                        } else if (backoff.hasNext()) {
+                        } else {
                             // We retry this since the creation of the superseding token document might already be in flight but not
                             // yet completed, triggered by a refresh request that came a few milliseconds ago
                             logger.info("could not find superseding token document [{}] for token document [{}], retrying",
-                                supersedingTokenDocId, tokenDocId);
-                            client.threadPool().schedule(() -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC);
-                        } else {
-                            logger.warn("could not find superseding token document [{}] for token document [{}] after all retries",
-                                supersedingTokenDocId, tokenDocId);
-                            onFailure.accept(invalidGrantException("could not refresh the requested token"));
+                                    supersedingTokenDocId, tokenDocId);
+                            maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token"));
                         }
                     }
 
                     @Override
                     public void onFailure(Exception e) {
                         if (isShardNotAvailableException(e)) {
-                            if (backoff.hasNext()) {
-                                logger.info("could not find superseding token document [{}] for refresh, retrying", supersedingTokenDocId);
-                                client.threadPool().schedule(
-                                    () -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC);
-                            } else {
-                                logger.warn("could not find token document [{}] for refresh after all retries", supersedingTokenDocId);
-                                onFailure.accept(invalidGrantException("could not refresh the requested token"));
-                            }
+                            logger.info("could not find superseding token document [{}] for refresh, retrying", supersedingTokenDocId);
+                            maybeRetryOnFailure.accept(invalidGrantException("could not refresh the requested token"));
                         } else {
                             logger.warn("could not find superseding token document [{}] for refresh", supersedingTokenDocId);
                             onFailure.accept(invalidGrantException("could not refresh the requested token"));
@@ -897,10 +899,10 @@ public final class TokenService {
                             } else if (backoff.hasNext()) {
                                 logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying",
                                     tokenDocId, updateResponse.getResult());
-                                client.threadPool().schedule(
-                                    () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
-                                        refreshRequested),
-                                    backoff.next(), GENERIC);
+                                final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
+                                        .preserveContext(() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener,
+                                                backoff, refreshRequested));
+                                client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
                             } else {
                                 logger.info("failed to update the original token document [{}] after all retries, " +
                                     "the update result was [{}]. ", tokenDocId, updateResponse.getResult());
@@ -910,51 +912,44 @@ public final class TokenService {
                             Throwable cause = ExceptionsHelper.unwrapCause(e);
                             if (cause instanceof VersionConflictEngineException) {
                                 //The document has been updated by another thread, get it again.
-                                if (backoff.hasNext()) {
-                                    logger.debug("version conflict while updating document [{}], attempting to get it again",
-                                        tokenDocId);
-                                    final ActionListener<GetResponse> getListener = new ActionListener<GetResponse>() {
-                                        @Override
-                                        public void onResponse(GetResponse response) {
-                                            if (response.isExists()) {
-                                                innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(),
-                                                    response.getPrimaryTerm(), clientAuth, listener, backoff, refreshRequested);
-                                            } else {
-                                                logger.warn("could not find token document [{}] for refresh", tokenDocId);
-                                                onFailure.accept(invalidGrantException("could not refresh the requested token"));
-                                            }
+                                logger.debug("version conflict while updating document [{}], attempting to get it again", tokenDocId);
+                                final ActionListener<GetResponse> getListener = new ActionListener<GetResponse>() {
+                                    @Override
+                                    public void onResponse(GetResponse response) {
+                                        if (response.isExists()) {
+                                            innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(), response.getPrimaryTerm(),
+                                                    clientAuth, listener, backoff, refreshRequested);
+                                        } else {
+                                            logger.warn("could not find token document [{}] for refresh", tokenDocId);
+                                            onFailure.accept(invalidGrantException("could not refresh the requested token"));
                                         }
-
-                                        @Override
-                                        public void onFailure(Exception e) {
-                                            if (isShardNotAvailableException(e)) {
-                                                if (backoff.hasNext()) {
-                                                    logger.info("could not get token document [{}] for refresh, " +
-                                                        "retrying", tokenDocId);
-                                                    client.threadPool().schedule(
-                                                        () -> getTokenDocAsync(tokenDocId, this), backoff.next(), GENERIC);
-                                                } else {
-                                                    logger.warn("could not get token document [{}] for refresh after all retries",
-                                                        tokenDocId);
-                                                    onFailure.accept(invalidGrantException("could not refresh the requested token"));
-                                                }
+                                    }
+
+                                    @Override
+                                    public void onFailure(Exception e) {
+                                        if (isShardNotAvailableException(e)) {
+                                            if (backoff.hasNext()) {
+                                                logger.info("could not get token document [{}] for refresh, retrying", tokenDocId);
+                                                final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
+                                                        .preserveContext(() -> getTokenDocAsync(tokenDocId, this));
+                                                client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
                                             } else {
-                                                onFailure.accept(e);
+                                                logger.warn("could not get token document [{}] for refresh after all retries", tokenDocId);
+                                                onFailure.accept(invalidGrantException("could not refresh the requested token"));
                                             }
+                                        } else {
+                                            onFailure.accept(e);
                                         }
-                                    };
-                                    getTokenDocAsync(tokenDocId, getListener);
-                                } else {
-                                    logger.warn("version conflict while updating document [{}], no retries left", tokenDocId);
-                                    onFailure.accept(invalidGrantException("could not refresh the requested token"));
-                                }
+                                    }
+                                };
+                                getTokenDocAsync(tokenDocId, getListener);
                             } else if (isShardNotAvailableException(e)) {
                                 if (backoff.hasNext()) {
                                     logger.debug("failed to update the original token document [{}], retrying", tokenDocId);
-                                    client.threadPool().schedule(
-                                        () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
-                                            refreshRequested),
-                                        backoff.next(), GENERIC);
+                                    final Runnable retryWithContextRunnable = client.threadPool().getThreadContext().preserveContext(
+                                            () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
+                                                    refreshRequested));
+                                    client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
                                 } else {
                                     logger.warn("failed to update the original token document [{}], after all retries", tokenDocId);
                                     onFailure.accept(invalidGrantException("could not refresh the requested token"));