|  | @@ -9,6 +9,7 @@
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  package org.elasticsearch.entitlement.runtime.policy;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.util.ArrayUtils;
 | 
	
		
			
				|  |  |  import org.elasticsearch.entitlement.runtime.policy.entitlements.Entitlement;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ESTestCase;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -17,6 +18,7 @@ import java.net.URISyntaxException;
 | 
	
		
			
				|  |  |  import java.nio.file.Path;
 | 
	
		
			
				|  |  |  import java.security.CodeSource;
 | 
	
		
			
				|  |  |  import java.security.ProtectionDomain;
 | 
	
		
			
				|  |  | +import java.util.Arrays;
 | 
	
		
			
				|  |  |  import java.util.Collection;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.Map;
 | 
	
	
		
			
				|  | @@ -29,6 +31,7 @@ public class TestPolicyManager extends PolicyManager {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      boolean isActive;
 | 
	
		
			
				|  |  |      boolean isTriviallyAllowingTestCode;
 | 
	
		
			
				|  |  | +    String[] entitledTestPackages = TEST_FRAMEWORK_PACKAGE_PREFIXES;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      /**
 | 
	
		
			
				|  |  |       * We don't have modules in tests, so we can't use the inherited map of entitlements per module.
 | 
	
	
		
			
				|  | @@ -60,6 +63,16 @@ public class TestPolicyManager extends PolicyManager {
 | 
	
		
			
				|  |  |          this.isTriviallyAllowingTestCode = newValue;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    public void setEntitledTestPackages(String... entitledTestPackages) {
 | 
	
		
			
				|  |  | +        assertNoRedundantPrefixes(TEST_FRAMEWORK_PACKAGE_PREFIXES, entitledTestPackages, false);
 | 
	
		
			
				|  |  | +        if (entitledTestPackages.length > 1) {
 | 
	
		
			
				|  |  | +            assertNoRedundantPrefixes(entitledTestPackages, entitledTestPackages, true);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        String[] packages = ArrayUtils.concat(this.entitledTestPackages, entitledTestPackages);
 | 
	
		
			
				|  |  | +        Arrays.sort(packages);
 | 
	
		
			
				|  |  | +        this.entitledTestPackages = packages;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      /**
 | 
	
		
			
				|  |  |       * Called between tests so each test is not affected by prior tests
 | 
	
		
			
				|  |  |       */
 | 
	
	
		
			
				|  | @@ -110,19 +123,47 @@ public class TestPolicyManager extends PolicyManager {
 | 
	
		
			
				|  |  |              && (requestingClass.getName().contains("Test") == false);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    @Deprecated // TODO: reevaluate whether we want this.
 | 
	
		
			
				|  |  | -    // If we can simply check for dependencies the gradle worker has that aren't
 | 
	
		
			
				|  |  | -    // declared in the gradle config (namely org.gradle) that would be simpler.
 | 
	
		
			
				|  |  |      private boolean isTestFrameworkClass(Class<?> requestingClass) {
 | 
	
		
			
				|  |  | -        String packageName = requestingClass.getPackageName();
 | 
	
		
			
				|  |  | -        for (String prefix : TEST_FRAMEWORK_PACKAGE_PREFIXES) {
 | 
	
		
			
				|  |  | -            if (packageName.startsWith(prefix)) {
 | 
	
		
			
				|  |  | +        return isTestFrameworkClass(entitledTestPackages, requestingClass.getPackageName());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    // no redundant entries allowed, see assertNoRedundantPrefixes
 | 
	
		
			
				|  |  | +    static boolean isTestFrameworkClass(String[] sortedPrefixes, String packageName) {
 | 
	
		
			
				|  |  | +        int idx = Arrays.binarySearch(sortedPrefixes, packageName);
 | 
	
		
			
				|  |  | +        if (idx >= 0) {
 | 
	
		
			
				|  |  | +            return true;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        idx = -idx - 2; // candidate package index (insertion point - 1)
 | 
	
		
			
				|  |  | +        if (idx >= 0 && idx < sortedPrefixes.length) {
 | 
	
		
			
				|  |  | +            String candidate = sortedPrefixes[idx];
 | 
	
		
			
				|  |  | +            if (packageName.startsWith(candidate)
 | 
	
		
			
				|  |  | +                && (packageName.length() == candidate.length() || packageName.charAt(candidate.length()) == '.')) {
 | 
	
		
			
				|  |  |                  return true;
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          return false;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    private static boolean isNotPrefixMatch(String name, String prefix, boolean discardExactMatch) {
 | 
	
		
			
				|  |  | +        assert prefix.endsWith(".") == false : "Invalid package prefix ending with '.' [" + prefix + "]";
 | 
	
		
			
				|  |  | +        if (name == prefix || name.startsWith(prefix)) {
 | 
	
		
			
				|  |  | +            if (name.length() == prefix.length()) {
 | 
	
		
			
				|  |  | +                return discardExactMatch;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            return false == (name.length() > prefix.length() && name.charAt(prefix.length()) == '.');
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return true;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    static void assertNoRedundantPrefixes(String[] setA, String[] setB, boolean discardExactMatch) {
 | 
	
		
			
				|  |  | +        for (String a : setA) {
 | 
	
		
			
				|  |  | +            for (String b : setB) {
 | 
	
		
			
				|  |  | +                assert isNotPrefixMatch(a, b, discardExactMatch) && isNotPrefixMatch(b, a, discardExactMatch)
 | 
	
		
			
				|  |  | +                    : "Redundant prefix entries: [" + a + ", " + b + "]";
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private boolean isTestCode(Class<?> requestingClass) {
 | 
	
		
			
				|  |  |          // TODO: Cache this? It's expensive
 | 
	
		
			
				|  |  |          for (Class<?> candidate = requireNonNull(requestingClass); candidate != null; candidate = candidate.getDeclaringClass()) {
 | 
	
	
		
			
				|  | @@ -163,6 +204,10 @@ public class TestPolicyManager extends PolicyManager {
 | 
	
		
			
				|  |  |          "org.bouncycastle.jsse.provider" // Used in test code if FIPS is enabled, support more fine-grained config in ES-12128
 | 
	
		
			
				|  |  |      };
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    static {
 | 
	
		
			
				|  |  | +        Arrays.sort(TEST_FRAMEWORK_PACKAGE_PREFIXES);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  |      protected ModuleEntitlements getEntitlements(Class<?> requestingClass) {
 | 
	
		
			
				|  |  |          return classEntitlementsMap.computeIfAbsent(requestingClass, this::computeEntitlements);
 |