Browse Source

Enhance `Ec2ImdsHttpHandler` (#119334) (#119383)

- Require IMDSv1 if using alternative endpoints (i.e. ECS)
- Forbid profile name lookup with alternative endpoints
- Add token TTL header for IMDSv2
- Add support for instance-identity docs
David Turner 9 months ago
parent
commit
3fe76378ec

+ 28 - 3
test/fixtures/ec2-imds-fixture/src/main/java/fixture/aws/imds/Ec2ImdsHttpHandler.java

@@ -14,8 +14,11 @@ import com.sun.net.httpserver.HttpHandler;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ToXContent;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -45,21 +48,38 @@ public class Ec2ImdsHttpHandler implements HttpHandler {
 
     private final BiConsumer<String, String> newCredentialsConsumer;
     private final Map<String, String> instanceAddresses;
-    private final Set<String> validCredentialsEndpoints = ConcurrentCollections.newConcurrentSet();
+    private final Set<String> validCredentialsEndpoints;
+    private final boolean dynamicProfileNames;
     private final Supplier<String> availabilityZoneSupplier;
+    @Nullable // if instance identity document not available
+    private final ToXContent instanceIdentityDocument;
 
     public Ec2ImdsHttpHandler(
         Ec2ImdsVersion ec2ImdsVersion,
         BiConsumer<String, String> newCredentialsConsumer,
         Collection<String> alternativeCredentialsEndpoints,
         Supplier<String> availabilityZoneSupplier,
+        @Nullable ToXContent instanceIdentityDocument,
         Map<String, String> instanceAddresses
     ) {
         this.ec2ImdsVersion = Objects.requireNonNull(ec2ImdsVersion);
         this.newCredentialsConsumer = Objects.requireNonNull(newCredentialsConsumer);
         this.instanceAddresses = instanceAddresses;
-        this.validCredentialsEndpoints.addAll(alternativeCredentialsEndpoints);
+
+        if (alternativeCredentialsEndpoints.isEmpty()) {
+            dynamicProfileNames = true;
+            validCredentialsEndpoints = ConcurrentCollections.newConcurrentSet();
+        } else if (ec2ImdsVersion == Ec2ImdsVersion.V2) {
+            throw new IllegalArgumentException(
+                Strings.format("alternative credentials endpoints %s requires IMDSv1", alternativeCredentialsEndpoints)
+            );
+        } else {
+            dynamicProfileNames = false;
+            validCredentialsEndpoints = Set.copyOf(alternativeCredentialsEndpoints);
+        }
+
         this.availabilityZoneSupplier = availabilityZoneSupplier;
+        this.instanceIdentityDocument = instanceIdentityDocument;
     }
 
     @Override
@@ -78,6 +98,8 @@ public class Ec2ImdsHttpHandler implements HttpHandler {
                         validImdsTokens.add(token);
                         final var responseBody = token.getBytes(StandardCharsets.UTF_8);
                         exchange.getResponseHeaders().add("Content-Type", "text/plain");
+                        exchange.getResponseHeaders()
+                            .add("x-aws-ec2-metadata-token-ttl-seconds", Long.toString(TimeValue.timeValueDays(1).seconds()));
                         exchange.sendResponseHeaders(RestStatus.OK.getStatus(), responseBody.length);
                         exchange.getResponseBody().write(responseBody);
                     }
@@ -98,7 +120,7 @@ public class Ec2ImdsHttpHandler implements HttpHandler {
             }
 
             if ("GET".equals(requestMethod)) {
-                if (path.equals(IMDS_SECURITY_CREDENTIALS_PATH)) {
+                if (path.equals(IMDS_SECURITY_CREDENTIALS_PATH) && dynamicProfileNames) {
                     final var profileName = randomIdentifier();
                     validCredentialsEndpoints.add(IMDS_SECURITY_CREDENTIALS_PATH + profileName);
                     sendStringResponse(exchange, profileName);
@@ -107,6 +129,9 @@ public class Ec2ImdsHttpHandler implements HttpHandler {
                     final var availabilityZone = availabilityZoneSupplier.get();
                     sendStringResponse(exchange, availabilityZone);
                     return;
+                } else if (instanceIdentityDocument != null && path.equals("/latest/dynamic/instance-identity/document")) {
+                    sendStringResponse(exchange, Strings.toString(instanceIdentityDocument));
+                    return;
                 } else if (validCredentialsEndpoints.contains(path)) {
                     final String accessKey = randomIdentifier();
                     final String sessionToken = randomIdentifier();

+ 7 - 0
test/fixtures/ec2-imds-fixture/src/main/java/fixture/aws/imds/Ec2ImdsServiceBuilder.java

@@ -10,6 +10,7 @@
 package fixture.aws.imds;
 
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.ToXContent;
 
 import java.util.Collection;
 import java.util.HashMap;
@@ -24,6 +25,7 @@ public class Ec2ImdsServiceBuilder {
     private BiConsumer<String, String> newCredentialsConsumer = Ec2ImdsServiceBuilder::rejectNewCredentials;
     private Collection<String> alternativeCredentialsEndpoints = Set.of();
     private Supplier<String> availabilityZoneSupplier = Ec2ImdsServiceBuilder::rejectAvailabilityZone;
+    private ToXContent instanceIdentityDocument = null;
     private final Map<String, String> instanceAddresses = new HashMap<>();
 
     public Ec2ImdsServiceBuilder(Ec2ImdsVersion ec2ImdsVersion) {
@@ -64,8 +66,13 @@ public class Ec2ImdsServiceBuilder {
             newCredentialsConsumer,
             alternativeCredentialsEndpoints,
             availabilityZoneSupplier,
+            instanceIdentityDocument,
             Map.copyOf(instanceAddresses)
         );
     }
 
+    public Ec2ImdsServiceBuilder instanceIdentityDocument(ToXContent instanceIdentityDocument) {
+        this.instanceIdentityDocument = instanceIdentityDocument;
+        return this;
+    }
 }

+ 64 - 18
test/fixtures/ec2-imds-fixture/src/test/java/fixture/aws/imds/Ec2ImdsHttpHandlerTests.java

@@ -52,16 +52,13 @@ public class Ec2ImdsHttpHandlerTests extends ESTestCase {
         assertTrue(Strings.hasText(profileName));
 
         final var credentialsResponse = handleRequest(handler, "GET", SECURITY_CREDENTIALS_URI + profileName);
-        assertEquals(RestStatus.OK, credentialsResponse.status());
 
         assertThat(generatedCredentials, aMapWithSize(1));
-        final var accessKey = generatedCredentials.keySet().iterator().next();
-        final var sessionToken = generatedCredentials.values().iterator().next();
-
-        final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
-        assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
-        assertEquals(accessKey, responseMap.get("AccessKeyId"));
-        assertEquals(sessionToken, responseMap.get("Token"));
+        assertValidCredentialsResponse(
+            credentialsResponse,
+            generatedCredentials.keySet().iterator().next(),
+            generatedCredentials.values().iterator().next()
+        );
     }
 
     public void testImdsV2Disabled() {
@@ -78,6 +75,7 @@ public class Ec2ImdsHttpHandlerTests extends ESTestCase {
 
         final var tokenResponse = handleRequest(handler, "PUT", "/latest/api/token");
         assertEquals(RestStatus.OK, tokenResponse.status());
+        assertEquals(List.of("86400" /* seconds in a day */), tokenResponse.responseHeaders().get("x-aws-ec2-metadata-token-ttl-seconds"));
         final var token = tokenResponse.body().utf8ToString();
 
         final var roleResponse = checkImdsV2GetRequest(handler, SECURITY_CREDENTIALS_URI, token);
@@ -86,16 +84,13 @@ public class Ec2ImdsHttpHandlerTests extends ESTestCase {
         assertTrue(Strings.hasText(profileName));
 
         final var credentialsResponse = checkImdsV2GetRequest(handler, SECURITY_CREDENTIALS_URI + profileName, token);
-        assertEquals(RestStatus.OK, credentialsResponse.status());
 
         assertThat(generatedCredentials, aMapWithSize(1));
-        final var accessKey = generatedCredentials.keySet().iterator().next();
-        final var sessionToken = generatedCredentials.values().iterator().next();
-
-        final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
-        assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
-        assertEquals(accessKey, responseMap.get("AccessKeyId"));
-        assertEquals(sessionToken, responseMap.get("Token"));
+        assertValidCredentialsResponse(
+            credentialsResponse,
+            generatedCredentials.keySet().iterator().next(),
+            generatedCredentials.values().iterator().next()
+        );
     }
 
     public void testAvailabilityZone() {
@@ -113,7 +108,54 @@ public class Ec2ImdsHttpHandlerTests extends ESTestCase {
         assertEquals(generatedAvailabilityZones, Set.of(availabilityZone));
     }
 
-    private record TestHttpResponse(RestStatus status, BytesReference body) {}
+    public void testAlternativeCredentialsEndpoint() throws IOException {
+        expectThrows(
+            IllegalArgumentException.class,
+            new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V2).alternativeCredentialsEndpoints(Set.of("/should-not-work"))::buildHandler
+        );
+
+        final var alternativePaths = randomList(1, 5, () -> "/" + randomIdentifier());
+        final Map<String, String> generatedCredentials = new HashMap<>();
+
+        final var handler = new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V1).alternativeCredentialsEndpoints(alternativePaths)
+            .newCredentialsConsumer(generatedCredentials::put)
+            .buildHandler();
+
+        final var credentialsResponse = handleRequest(handler, "GET", randomFrom(alternativePaths));
+
+        assertThat(generatedCredentials, aMapWithSize(1));
+        assertValidCredentialsResponse(
+            credentialsResponse,
+            generatedCredentials.keySet().iterator().next(),
+            generatedCredentials.values().iterator().next()
+        );
+    }
+
+    private static void assertValidCredentialsResponse(TestHttpResponse credentialsResponse, String accessKey, String sessionToken)
+        throws IOException {
+        assertEquals(RestStatus.OK, credentialsResponse.status());
+        final var responseMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), credentialsResponse.body().streamInput(), false);
+        assertEquals(Set.of("AccessKeyId", "Expiration", "RoleArn", "SecretAccessKey", "Token"), responseMap.keySet());
+        assertEquals(accessKey, responseMap.get("AccessKeyId"));
+        assertEquals(sessionToken, responseMap.get("Token"));
+    }
+
+    public void testInstanceIdentityDocument() {
+        final Set<String> generatedRegions = new HashSet<>();
+        final var handler = new Ec2ImdsServiceBuilder(Ec2ImdsVersion.V1).instanceIdentityDocument((builder, params) -> {
+            final var newRegion = randomIdentifier();
+            generatedRegions.add(newRegion);
+            return builder.field("region", newRegion);
+        }).buildHandler();
+
+        final var instanceIdentityResponse = handleRequest(handler, "GET", "/latest/dynamic/instance-identity/document");
+        assertEquals(RestStatus.OK, instanceIdentityResponse.status());
+        final var instanceIdentityString = instanceIdentityResponse.body().utf8ToString();
+
+        assertEquals(Strings.format("{\"region\":\"%s\"}", generatedRegions.iterator().next()), instanceIdentityString);
+    }
+
+    private record TestHttpResponse(RestStatus status, Headers responseHeaders, BytesReference body) {}
 
     private static TestHttpResponse checkImdsV2GetRequest(Ec2ImdsHttpHandler handler, String uri, String token) {
         final var unauthorizedResponse = handleRequest(handler, "GET", uri, null);
@@ -145,7 +187,11 @@ public class Ec2ImdsHttpHandlerTests extends ESTestCase {
             fail(e);
         }
         assertNotEquals(0, httpExchange.getResponseCode());
-        return new TestHttpResponse(RestStatus.fromCode(httpExchange.getResponseCode()), httpExchange.getResponseBodyContents());
+        return new TestHttpResponse(
+            RestStatus.fromCode(httpExchange.getResponseCode()),
+            httpExchange.getResponseHeaders(),
+            httpExchange.getResponseBodyContents()
+        );
     }
 
     private static class TestHttpExchange extends HttpExchange {