Quellcode durchsuchen

Automatically close idle connections in OIDC back-channel (#87773)

In some environment, the back-channel connection can be dropped
without sending a TCP RST to ES. When that happens, reusing the same
connection results into timeout error.

This PR adds a new http.connection_pool_ttl setting to control how long
a connection in the OIDC back-channel pool can be idle before it is
closed. This allows ES to more actively close idle connections to avoid
the timeout issue.

The new setting has a 3min default which means idle connections are
closed every 3 min if server response does not specify a shorter keep-alive.

Resolves: #75515
Yang Wang vor 3 Jahren
Ursprung
Commit
f075d505c5

+ 5 - 0
docs/changelog/87773.yaml

@@ -0,0 +1,5 @@
+pr: 87773
+summary: Automatically close idle connections in OIDC back-channel
+area: Security
+type: enhancement
+issues: []

+ 15 - 0
docs/reference/settings/security-settings.asciidoc

@@ -1858,6 +1858,21 @@ connections allowed per endpoint.
 Defaults to `200`.
 // end::oidc-http-max-endpoint-connections-tag[]
 
+// tag::oidc-http-connection-pool-ttl-tag[]
+`http.connection_pool_ttl` {ess-icon}::
+(<<static-cluster-setting,Static>>)
+Controls the behavior of the http client used for back-channel communication to
+the OpenID Connect Provider endpoints. Specifies the time-to-live of connections
+in the connection pool (default to 3 minutes). A connection is closed if it is
+idle for more than the specified timeout.
+
+The server can also set the `Keep-Alive` HTTP response header. The effective
+time-to-live value is the smaller value between this setting and the `Keep-Alive`
+reponse header. Configure this setting to `-1` to let the server dictate the value.
+If the header is not set by the server and the setting has value of `-1`,
+the time-to-live is infinite and connections never expire.
+// end::oidc-http-connection-pool-ttl-tag[]
+
 [discrete]
 [[ref-oidc-ssl-settings]]
 ===== OpenID Connect realm SSL settings

+ 8 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/oidc/OpenIdConnectRealmSettings.java

@@ -23,6 +23,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 
 public class OpenIdConnectRealmSettings {
@@ -212,6 +213,12 @@ public class OpenIdConnectRealmSettings {
         "http.max_endpoint_connections",
         key -> Setting.intSetting(key, 200, Setting.Property.NodeScope)
     );
+
+    public static final Setting.AffixSetting<TimeValue> HTTP_CONNECTION_POOL_TTL = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "http.connection_pool_ttl",
+        key -> Setting.timeSetting(key, new TimeValue(3, TimeUnit.MINUTES), Setting.Property.NodeScope)
+    );
     public static final Setting.AffixSetting<String> HTTP_PROXY_HOST = Setting.affixKeySetting(
         RealmSettings.realmSettingPrefix(TYPE),
         "http.proxy.host",
@@ -307,6 +314,7 @@ public class OpenIdConnectRealmSettings {
             HTTP_SOCKET_TIMEOUT,
             HTTP_MAX_CONNECTIONS,
             HTTP_MAX_ENDPOINT_CONNECTIONS,
+            HTTP_CONNECTION_POOL_TTL,
             HTTP_PROXY_HOST,
             HTTP_PROXY_PORT,
             HTTP_PROXY_SCHEME,

+ 1 - 0
x-pack/plugin/security/qa/smoke-test-all-realms/build.gradle

@@ -76,6 +76,7 @@ testClusters.matching { it.name == 'javaRestTest' }.configureEach {
   setting 'xpack.security.authc.realms.oidc.openid7.op.authorization_endpoint', 'https://op.example.com/auth'
   setting 'xpack.security.authc.realms.oidc.openid7.op.jwkset_path', 'oidc-jwkset.json'
   setting 'xpack.security.authc.realms.oidc.openid7.claims.principal', 'sub'
+  setting 'xpack.security.authc.realms.oidc.openid7.http.connection_pool_ttl', '1m'
   keystore 'xpack.security.authc.realms.oidc.openid7.rp.client_secret', 'this-is-my-secret'
   //  - JWT (works)
   setting 'xpack.security.authc.realms.jwt.jwt8.order', '8'

+ 32 - 1
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java

@@ -58,8 +58,10 @@ import org.apache.http.client.methods.HttpPost;
 import org.apache.http.concurrent.FutureCallback;
 import org.apache.http.config.Registry;
 import org.apache.http.config.RegistryBuilder;
+import org.apache.http.conn.ConnectionKeepAliveStrategy;
 import org.apache.http.entity.ContentType;
 import org.apache.http.impl.auth.BasicScheme;
+import org.apache.http.impl.client.DefaultConnectionKeepAliveStrategy;
 import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
 import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
 import org.apache.http.impl.nio.client.HttpAsyncClients;
@@ -114,6 +116,7 @@ import javax.net.ssl.SSLContext;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings.ALLOWED_CLOCK_SKEW;
+import static org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings.HTTP_CONNECTION_POOL_TTL;
 import static org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings.HTTP_CONNECTION_READ_TIMEOUT;
 import static org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings.HTTP_CONNECT_TIMEOUT;
 import static org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings.HTTP_MAX_CONNECTIONS;
@@ -705,9 +708,11 @@ public class OpenIdConnectAuthenticator {
                     .setConnectionRequestTimeout(Math.toIntExact(realmConfig.getSetting(HTTP_CONNECTION_READ_TIMEOUT).getSeconds()))
                     .setSocketTimeout(Math.toIntExact(realmConfig.getSetting(HTTP_SOCKET_TIMEOUT).getMillis()))
                     .build();
+
                 HttpAsyncClientBuilder httpAsyncClientBuilder = HttpAsyncClients.custom()
                     .setConnectionManager(connectionManager)
-                    .setDefaultRequestConfig(requestConfig);
+                    .setDefaultRequestConfig(requestConfig)
+                    .setKeepAliveStrategy(getKeepAliveStrategy());
                 if (realmConfig.hasSetting(HTTP_PROXY_HOST)) {
                     httpAsyncClientBuilder.setProxy(
                         new HttpHost(
@@ -726,6 +731,32 @@ public class OpenIdConnectAuthenticator {
         }
     }
 
+    // Package private for testing
+    CloseableHttpAsyncClient getHttpClient() {
+        return httpClient;
+    }
+
+    // Package private for testing
+    ConnectionKeepAliveStrategy getKeepAliveStrategy() {
+        final long userConfiguredKeepAlive = realmConfig.getSetting(HTTP_CONNECTION_POOL_TTL).millis();
+        return (response, context) -> {
+            var serverKeepAlive = DefaultConnectionKeepAliveStrategy.INSTANCE.getKeepAliveDuration(response, context);
+            long actualKeepAlive;
+            if (serverKeepAlive <= -1) {
+                actualKeepAlive = userConfiguredKeepAlive;
+            } else if (userConfiguredKeepAlive <= -1) {
+                actualKeepAlive = serverKeepAlive;
+            } else {
+                actualKeepAlive = Math.min(serverKeepAlive, userConfiguredKeepAlive);
+            }
+            if (actualKeepAlive < -1) {
+                actualKeepAlive = -1;
+            }
+            LOGGER.debug("effective HTTP connection keep-alive: [{}]ms", actualKeepAlive);
+            return actualKeepAlive;
+        };
+    }
+
     /*
      * Creates an {@link IDTokenValidator} based on the current Relying Party configuration
      */

+ 228 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticatorTests.java

@@ -44,14 +44,22 @@ import com.nimbusds.openid.connect.sdk.Nonce;
 import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash;
 import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator;
 import com.nimbusds.openid.connect.sdk.validators.InvalidHashException;
+import com.sun.net.httpserver.HttpServer;
 
+import org.apache.http.HeaderIterator;
 import org.apache.http.HttpResponse;
 import org.apache.http.HttpVersion;
 import org.apache.http.ProtocolVersion;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.concurrent.FutureCallback;
+import org.apache.http.conn.ConnectionKeepAliveStrategy;
 import org.apache.http.entity.ContentType;
 import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager;
+import org.apache.http.message.BasicHeader;
 import org.apache.http.message.BasicHttpResponse;
 import org.apache.http.message.BasicStatusLine;
+import org.apache.http.protocol.HTTP;
 import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -59,6 +67,7 @@ import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.logging.Loggers;
+import org.elasticsearch.common.network.InetAddresses;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -66,15 +75,20 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
+import org.elasticsearch.mocksocket.MockHttpServer;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.MockLogAppender;
 import org.elasticsearch.test.TestMatchers;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
+import org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings;
 import org.elasticsearch.xpack.core.ssl.SSLService;
 import org.junit.After;
 import org.junit.Before;
 import org.mockito.Mockito;
 
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.nio.charset.StandardCharsets;
@@ -89,18 +103,24 @@ import java.security.interfaces.RSAPublicKey;
 import java.util.Base64;
 import java.util.Collections;
 import java.util.Date;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
 
 import javax.crypto.SecretKey;
 import javax.crypto.spec.SecretKeySpec;
 
 import static java.time.Instant.now;
+import static org.elasticsearch.xpack.core.security.authc.RealmSettings.getFullSettingKey;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -1006,6 +1026,214 @@ public class OpenIdConnectAuthenticatorTests extends OpenIdConnectTestCase {
         }
     }
 
+    public void testHttpClientConnectionTtlBehaviour() throws URISyntaxException, IllegalAccessException, InterruptedException,
+        IOException {
+        // Create an internal HTTP server, the expectation is: For 2 consecutive HTTP requests, the client port should be different
+        // because the client should not reuse the same connection after 1s
+        final HttpServer httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
+        httpServer.start();
+
+        final AtomicReference<Integer> firstClientPort = new AtomicReference<>(null);
+        final AtomicReference<Boolean> portTested = new AtomicReference<>(false);
+        httpServer.createContext("/", exchange -> {
+            try {
+                final int currentPort = exchange.getRemoteAddress().getPort();
+                // Either set the first port number, otherwise the current (2nd) port number should be different from the 1st one
+                if (false == firstClientPort.compareAndSet(null, currentPort)) {
+                    assertThat(currentPort, not(equalTo(firstClientPort.get())));
+                    portTested.set(true);
+                }
+                final byte[] bytes = randomByteArrayOfLength(2);
+                exchange.sendResponseHeaders(200, bytes.length);
+                exchange.getResponseBody().write(bytes);
+            } finally {
+                exchange.close();
+            }
+        });
+
+        final InetSocketAddress address = httpServer.getAddress();
+        final URI uri = new URI("http://" + InetAddresses.toUriString(address.getAddress()) + ":" + address.getPort());
+
+        // Authenticator with a short TTL
+        final RealmConfig config = buildConfig(
+            getBasicRealmSettings().put(getFullSettingKey(REALM_NAME, OpenIdConnectRealmSettings.HTTP_CONNECTION_POOL_TTL), "1s").build(),
+            threadContext
+        );
+        authenticator = new OpenIdConnectAuthenticator(config, getOpConfig(), getDefaultRpConfig(), new SSLService(env), null);
+
+        // In addition, capture logs to show that kept alive (TTL) is honored
+        final Logger logger = LogManager.getLogger(PoolingNHttpClientConnectionManager.class);
+        final MockLogAppender appender = new MockLogAppender();
+        appender.start();
+        Loggers.addAppender(logger, appender);
+        Loggers.setLevel(logger, Level.DEBUG);
+        try {
+            appender.addExpectation(
+                new MockLogAppender.PatternSeenEventExpectation(
+                    "log",
+                    logger.getName(),
+                    Level.DEBUG,
+                    ".*Connection .* can be kept alive for 1.0 seconds"
+                )
+            );
+            // Issue two requests to verify the 2nd request do not reuse the 1st request's connection
+            for (int i = 0; i < 2; i++) {
+                final CountDownLatch latch = new CountDownLatch(1);
+                authenticator.getHttpClient().execute(new HttpGet(uri), new FutureCallback<>() {
+                    @Override
+                    public void completed(HttpResponse result) {
+                        latch.countDown();
+                    }
+
+                    @Override
+                    public void failed(Exception ex) {
+                        assert false;
+                    }
+
+                    @Override
+                    public void cancelled() {
+                        assert false;
+                    }
+                });
+                latch.await();
+                Thread.sleep(1500);
+            }
+            appender.assertAllExpectationsMatched();
+            assertThat(portTested.get(), is(true));
+        } finally {
+            Loggers.removeAppender(logger, appender);
+            appender.stop();
+            Loggers.setLevel(logger, (Level) null);
+            authenticator.close();
+            httpServer.stop(1);
+        }
+    }
+
+    public void testKeepAliveStrategy() throws URISyntaxException, IllegalAccessException {
+        // Neither server nor client has explicit configuration
+        doTestKeepAliveStrategy(null, null, 180_000L);
+
+        // Client explicitly configures for 100s
+        doTestKeepAliveStrategy(null, "100", 100_000L);
+
+        // Server explicitly configures for 400s, but client's default is 180s
+        doTestKeepAliveStrategy("400", null, 180_000L);
+
+        // Server explicitly configures for 120s
+        doTestKeepAliveStrategy("120", null, 120_000L);
+
+        // Both server and client explicitly configures it
+        doTestKeepAliveStrategy("120", "90", 90_000L);
+
+        // Both server and client explicitly configures it
+        doTestKeepAliveStrategy("80", "90", 80_000L);
+
+        // Server configures negative value
+        doTestKeepAliveStrategy(String.valueOf(randomIntBetween(-100, -1)), null, 180_000L);
+        doTestKeepAliveStrategy(String.valueOf(randomIntBetween(-100, -1)), "400", 400_000L);
+
+        // Client configures negative value, -1 is the only negative number accepted by timeSetting
+        doTestKeepAliveStrategy(null, "-1", -1L);
+        doTestKeepAliveStrategy("30", "-1", 30_000L);
+
+        // Both server and client explicitly configures negative values
+        doTestKeepAliveStrategy(String.valueOf(randomIntBetween(-100, -1)), "-1", -1L);
+
+        // Extra randomization
+        final int serverTtlInSeconds;
+        if (randomBoolean()) {
+            serverTtlInSeconds = randomIntBetween(-1, 300);
+        } else {
+            // Server may not set the response header
+            serverTtlInSeconds = -1;
+        }
+
+        final int clientTtlInSeconds;
+        if (randomBoolean()) {
+            clientTtlInSeconds = randomIntBetween(-1, 300);
+        } else {
+            clientTtlInSeconds = 180; // default 180s
+        }
+
+        final int effectiveTtlInSeconds;
+        if (serverTtlInSeconds <= -1) {
+            effectiveTtlInSeconds = clientTtlInSeconds;
+        } else if (clientTtlInSeconds <= -1) {
+            effectiveTtlInSeconds = serverTtlInSeconds;
+        } else {
+            effectiveTtlInSeconds = Math.min(serverTtlInSeconds, clientTtlInSeconds);
+        }
+        final long effectiveTtlInMs = effectiveTtlInSeconds <= -1 ? -1L : effectiveTtlInSeconds * 1000L;
+
+        doTestKeepAliveStrategy(
+            serverTtlInSeconds == -1 ? randomFrom(String.valueOf(serverTtlInSeconds), null) : String.valueOf(serverTtlInSeconds),
+            clientTtlInSeconds == 180 ? randomFrom(String.valueOf(clientTtlInSeconds), null) : String.valueOf(clientTtlInSeconds),
+            effectiveTtlInMs
+        );
+    }
+
+    public void doTestKeepAliveStrategy(String serverTtlInSeconds, String clientTtlInSeconds, long effectiveTtlInMs)
+        throws URISyntaxException, IllegalAccessException {
+        final HttpResponse httpResponse = mock(HttpResponse.class);
+        final Iterator<BasicHeader> iterator;
+        if (serverTtlInSeconds != null) {
+            iterator = List.of(new BasicHeader("Keep-Alive", "timeout=" + serverTtlInSeconds)).iterator();
+        } else {
+            // Server may not set the response header
+            iterator = Collections.emptyIterator();
+        }
+        when(httpResponse.headerIterator(HTTP.CONN_KEEP_ALIVE)).thenReturn(new HeaderIterator() {
+            @Override
+            public boolean hasNext() {
+                return iterator.hasNext();
+            }
+
+            @Override
+            public org.apache.http.Header nextHeader() {
+                return iterator.next();
+            }
+
+            @Override
+            public Object next() {
+                return iterator.next();
+            }
+        });
+
+        final Settings.Builder settingsBuilder = getBasicRealmSettings();
+        if (clientTtlInSeconds != null) {
+            settingsBuilder.put(
+                getFullSettingKey(REALM_NAME, OpenIdConnectRealmSettings.HTTP_CONNECTION_POOL_TTL),
+                clientTtlInSeconds + "s"
+            );
+        }
+        final RealmConfig config = buildConfig(settingsBuilder.build(), threadContext);
+        authenticator = new OpenIdConnectAuthenticator(config, getOpConfig(), getDefaultRpConfig(), new SSLService(env), null);
+
+        final Logger logger = LogManager.getLogger(OpenIdConnectAuthenticator.class);
+        final MockLogAppender appender = new MockLogAppender();
+        appender.start();
+        Loggers.addAppender(logger, appender);
+        Loggers.setLevel(logger, Level.DEBUG);
+        try {
+            appender.addExpectation(
+                new MockLogAppender.SeenEventExpectation(
+                    "log",
+                    logger.getName(),
+                    Level.DEBUG,
+                    "effective HTTP connection keep-alive: [" + effectiveTtlInMs + "]ms"
+                )
+            );
+            final ConnectionKeepAliveStrategy keepAliveStrategy = authenticator.getKeepAliveStrategy();
+            assertThat(keepAliveStrategy.getKeepAliveDuration(httpResponse, null), equalTo(effectiveTtlInMs));
+            appender.assertAllExpectationsMatched();
+        } finally {
+            Loggers.removeAppender(logger, appender);
+            appender.stop();
+            Loggers.setLevel(logger, (Level) null);
+            authenticator.close();
+        }
+    }
+
     private OpenIdConnectProviderConfiguration getOpConfig() throws URISyntaxException {
         return new OpenIdConnectProviderConfiguration(
             new Issuer("https://op.example.com"),