Browse Source

CreateClassLoaderEntitlement + extensions to parse logic (#117754) (#117978)

Lorenzo Dematté 10 months ago
parent
commit
36d8307abd

+ 16 - 0
libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/CreateClassLoaderEntitlement.java

@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.entitlement.runtime.policy;
+
+public class CreateClassLoaderEntitlement implements Entitlement {
+    @ExternalEntitlement
+    public CreateClassLoaderEntitlement() {}
+
+}

+ 51 - 30
libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyParser.java

@@ -19,22 +19,43 @@ import java.io.UncheckedIOException;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
-
-import static org.elasticsearch.entitlement.runtime.policy.PolicyParserException.newPolicyParserException;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * A parser to parse policy files for entitlements.
  */
 public class PolicyParser {
 
-    protected static final String entitlementPackageName = Entitlement.class.getPackage().getName();
+    private static final Map<String, Class<?>> EXTERNAL_ENTITLEMENTS = Stream.of(FileEntitlement.class, CreateClassLoaderEntitlement.class)
+        .collect(Collectors.toUnmodifiableMap(PolicyParser::getEntitlementTypeName, Function.identity()));
 
     protected final XContentParser policyParser;
     protected final String policyName;
 
+    static String getEntitlementTypeName(Class<? extends Entitlement> entitlementClass) {
+        var entitlementClassName = entitlementClass.getSimpleName();
+
+        if (entitlementClassName.endsWith("Entitlement") == false) {
+            throw new IllegalArgumentException(
+                entitlementClassName + " is not a valid Entitlement class name. A valid class name must end with 'Entitlement'"
+            );
+        }
+
+        var strippedClassName = entitlementClassName.substring(0, entitlementClassName.indexOf("Entitlement"));
+        return Arrays.stream(strippedClassName.split("(?=\\p{Lu})"))
+            .filter(Predicate.not(String::isEmpty))
+            .map(s -> s.toLowerCase(Locale.ROOT))
+            .collect(Collectors.joining("_"));
+    }
+
     public PolicyParser(InputStream inputStream, String policyName) throws IOException {
         this.policyParser = YamlXContent.yamlXContent.createParser(XContentParserConfiguration.EMPTY, Objects.requireNonNull(inputStream));
         this.policyName = policyName;
@@ -67,18 +88,23 @@ public class PolicyParser {
             }
             List<Entitlement> entitlements = new ArrayList<>();
             while (policyParser.nextToken() != XContentParser.Token.END_ARRAY) {
-                if (policyParser.currentToken() != XContentParser.Token.START_OBJECT) {
-                    throw newPolicyParserException(scopeName, "expected object <entitlement type>");
-                }
-                if (policyParser.nextToken() != XContentParser.Token.FIELD_NAME) {
+                if (policyParser.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    String entitlementType = policyParser.text();
+                    Entitlement entitlement = parseEntitlement(scopeName, entitlementType);
+                    entitlements.add(entitlement);
+                } else if (policyParser.currentToken() == XContentParser.Token.START_OBJECT) {
+                    if (policyParser.nextToken() != XContentParser.Token.FIELD_NAME) {
+                        throw newPolicyParserException(scopeName, "expected object <entitlement type>");
+                    }
+                    String entitlementType = policyParser.currentName();
+                    Entitlement entitlement = parseEntitlement(scopeName, entitlementType);
+                    entitlements.add(entitlement);
+                    if (policyParser.nextToken() != XContentParser.Token.END_OBJECT) {
+                        throw newPolicyParserException(scopeName, "expected closing object");
+                    }
+                } else {
                     throw newPolicyParserException(scopeName, "expected object <entitlement type>");
                 }
-                String entitlementType = policyParser.currentName();
-                Entitlement entitlement = parseEntitlement(scopeName, entitlementType);
-                entitlements.add(entitlement);
-                if (policyParser.nextToken() != XContentParser.Token.END_OBJECT) {
-                    throw newPolicyParserException(scopeName, "expected closing object");
-                }
             }
             return new Scope(scopeName, entitlements);
         } catch (IOException ioe) {
@@ -87,34 +113,29 @@ public class PolicyParser {
     }
 
     protected Entitlement parseEntitlement(String scopeName, String entitlementType) throws IOException {
-        Class<?> entitlementClass;
-        try {
-            entitlementClass = Class.forName(
-                entitlementPackageName
-                    + "."
-                    + Character.toUpperCase(entitlementType.charAt(0))
-                    + entitlementType.substring(1)
-                    + "Entitlement"
-            );
-        } catch (ClassNotFoundException cnfe) {
-            throw newPolicyParserException(scopeName, "unknown entitlement type [" + entitlementType + "]");
-        }
-        if (Entitlement.class.isAssignableFrom(entitlementClass) == false) {
+        Class<?> entitlementClass = EXTERNAL_ENTITLEMENTS.get(entitlementType);
+
+        if (entitlementClass == null) {
             throw newPolicyParserException(scopeName, "unknown entitlement type [" + entitlementType + "]");
         }
+
         Constructor<?> entitlementConstructor = entitlementClass.getConstructors()[0];
         ExternalEntitlement entitlementMetadata = entitlementConstructor.getAnnotation(ExternalEntitlement.class);
         if (entitlementMetadata == null) {
             throw newPolicyParserException(scopeName, "unknown entitlement type [" + entitlementType + "]");
         }
 
-        if (policyParser.nextToken() != XContentParser.Token.START_OBJECT) {
-            throw newPolicyParserException(scopeName, entitlementType, "expected entitlement parameters");
+        Class<?>[] parameterTypes = entitlementConstructor.getParameterTypes();
+        String[] parametersNames = entitlementMetadata.parameterNames();
+
+        if (parameterTypes.length != 0 || parametersNames.length != 0) {
+            if (policyParser.nextToken() != XContentParser.Token.START_OBJECT) {
+                throw newPolicyParserException(scopeName, entitlementType, "expected entitlement parameters");
+            }
         }
+
         Map<String, Object> parsedValues = policyParser.map();
 
-        Class<?>[] parameterTypes = entitlementConstructor.getParameterTypes();
-        String[] parametersNames = entitlementMetadata.parameterNames();
         Object[] parameterValues = new Object[parameterTypes.length];
         for (int parameterIndex = 0; parameterIndex < parameterTypes.length; ++parameterIndex) {
             String parameterName = parametersNames[parameterIndex];

+ 3 - 4
libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserFailureTests.java

@@ -12,7 +12,6 @@ package org.elasticsearch.entitlement.runtime.policy;
 import org.elasticsearch.test.ESTestCase;
 
 import java.io.ByteArrayInputStream;
-import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 
 public class PolicyParserFailureTests extends ESTestCase {
@@ -26,7 +25,7 @@ public class PolicyParserFailureTests extends ESTestCase {
         assertEquals("[1:1] policy parsing error for [test-failure-policy.yaml]: expected object <scope name>", ppe.getMessage());
     }
 
-    public void testEntitlementDoesNotExist() throws IOException {
+    public void testEntitlementDoesNotExist() {
         PolicyParserException ppe = expectThrows(PolicyParserException.class, () -> new PolicyParser(new ByteArrayInputStream("""
             entitlement-module-name:
               - does_not_exist: {}
@@ -38,7 +37,7 @@ public class PolicyParserFailureTests extends ESTestCase {
         );
     }
 
-    public void testEntitlementMissingParameter() throws IOException {
+    public void testEntitlementMissingParameter() {
         PolicyParserException ppe = expectThrows(PolicyParserException.class, () -> new PolicyParser(new ByteArrayInputStream("""
             entitlement-module-name:
               - file: {}
@@ -61,7 +60,7 @@ public class PolicyParserFailureTests extends ESTestCase {
         );
     }
 
-    public void testEntitlementExtraneousParameter() throws IOException {
+    public void testEntitlementExtraneousParameter() {
         PolicyParserException ppe = expectThrows(PolicyParserException.class, () -> new PolicyParser(new ByteArrayInputStream("""
             entitlement-module-name:
               - file:

+ 39 - 0
libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserTests.java

@@ -11,11 +11,31 @@ package org.elasticsearch.entitlement.runtime.policy;
 
 import org.elasticsearch.test.ESTestCase;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.util.List;
 
+import static org.elasticsearch.test.LambdaMatchers.transformedMatch;
+import static org.hamcrest.Matchers.both;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
 public class PolicyParserTests extends ESTestCase {
 
+    private static class TestWrongEntitlementName implements Entitlement {}
+
+    public void testGetEntitlementTypeName() {
+        assertEquals("create_class_loader", PolicyParser.getEntitlementTypeName(CreateClassLoaderEntitlement.class));
+
+        var ex = expectThrows(IllegalArgumentException.class, () -> PolicyParser.getEntitlementTypeName(TestWrongEntitlementName.class));
+        assertThat(
+            ex.getMessage(),
+            equalTo("TestWrongEntitlementName is not a valid Entitlement class name. A valid class name must end with 'Entitlement'")
+        );
+    }
+
     public void testPolicyBuilder() throws IOException {
         Policy parsedPolicy = new PolicyParser(PolicyParserTests.class.getResourceAsStream("test-policy.yaml"), "test-policy.yaml")
             .parsePolicy();
@@ -25,4 +45,23 @@ public class PolicyParserTests extends ESTestCase {
         );
         assertEquals(parsedPolicy, builtPolicy);
     }
+
+    public void testParseCreateClassloader() throws IOException {
+        Policy parsedPolicy = new PolicyParser(new ByteArrayInputStream("""
+            entitlement-module-name:
+              - create_class_loader
+            """.getBytes(StandardCharsets.UTF_8)), "test-policy.yaml").parsePolicy();
+        Policy builtPolicy = new Policy(
+            "test-policy.yaml",
+            List.of(new Scope("entitlement-module-name", List.of(new CreateClassLoaderEntitlement())))
+        );
+        assertThat(
+            parsedPolicy.scopes,
+            contains(
+                both(transformedMatch((Scope scope) -> scope.name, equalTo("entitlement-module-name"))).and(
+                    transformedMatch(scope -> scope.entitlements, contains(instanceOf(CreateClassLoaderEntitlement.class)))
+                )
+            )
+        );
+    }
 }