Browse Source

[Entitlements] Refactor InstrumenterImpl tests (#117688) (#117791)

Following up
https://github.com/elastic/elasticsearch/pull/117332#discussion_r1856803255,
I refactored `InstrumenterImpl` tests, splitting them into 2 suites:  -
`SyntheticInstrumenterImplTests`, which tests the mechanics of
instrumentation using ad-hoc test cases. This should see little change
now that we have our Instrumenter working as intended -
`InstrumenterImplTests`, which is back to its original intent to make
sure (1) the right arguments make it all the way to the check methods,
and (2) if the check method throws, that exception correctly bubbles up
through the instrumented method.

The PR also includes a little change to `InstrumenterImpl`  construction
to clean it up a bit and make it more testable.
Lorenzo Dematté 10 months ago
parent
commit
2f64565c75

+ 7 - 21
libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java

@@ -9,7 +9,7 @@
 
 package org.elasticsearch.entitlement.instrumentation.impl;
 
-import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
 import org.elasticsearch.entitlement.instrumentation.Instrumenter;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
@@ -20,37 +20,23 @@ import org.objectweb.asm.Opcodes;
 import org.objectweb.asm.Type;
 
 import java.io.IOException;
-import java.lang.reflect.Method;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
-import java.util.stream.Stream;
 
 public class InstrumentationServiceImpl implements InstrumentationService {
 
     @Override
-    public Instrumenter newInstrumenter(String classNameSuffix, Map<MethodKey, CheckerMethod> instrumentationMethods) {
-        return new InstrumenterImpl(classNameSuffix, instrumentationMethods);
-    }
-
-    /**
-     * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
-     */
-    public MethodKey methodKeyForTarget(Method targetMethod) {
-        Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
-        return new MethodKey(
-            Type.getInternalName(targetMethod.getDeclaringClass()),
-            targetMethod.getName(),
-            Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
-        );
+    public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
+        return InstrumenterImpl.create(checkMethods);
     }
 
     @Override
-    public Map<MethodKey, CheckerMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
+    public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
         IOException {
-        var methodsToInstrument = new HashMap<MethodKey, CheckerMethod>();
+        var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
         var checkerClass = Class.forName(entitlementCheckerClassName);
         var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
         ClassReader reader = new ClassReader(classFileInfo.bytecodes());
@@ -69,9 +55,9 @@ public class InstrumentationServiceImpl implements InstrumentationService {
                 var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);
 
                 var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
-                var checkerMethod = new CheckerMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);
+                var checkMethod = new CheckMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);
 
-                methodsToInstrument.put(methodToInstrument, checkerMethod);
+                methodsToInstrument.put(methodToInstrument, checkMethod);
 
                 return mv;
             }

+ 37 - 24
libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java

@@ -9,7 +9,7 @@
 
 package org.elasticsearch.entitlement.instrumentation.impl;
 
-import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.Instrumenter;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
 import org.objectweb.asm.AnnotationVisitor;
@@ -37,9 +37,28 @@ import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
 
 public class InstrumenterImpl implements Instrumenter {
 
-    private static final String checkerClassDescriptor;
-    private static final String handleClass;
-    static {
+    private final String getCheckerClassMethodDescriptor;
+    private final String handleClass;
+
+    /**
+     * To avoid class name collisions during testing without an agent to replace classes in-place.
+     */
+    private final String classNameSuffix;
+    private final Map<MethodKey, CheckMethod> checkMethods;
+
+    InstrumenterImpl(
+        String handleClass,
+        String getCheckerClassMethodDescriptor,
+        String classNameSuffix,
+        Map<MethodKey, CheckMethod> checkMethods
+    ) {
+        this.handleClass = handleClass;
+        this.getCheckerClassMethodDescriptor = getCheckerClassMethodDescriptor;
+        this.classNameSuffix = classNameSuffix;
+        this.checkMethods = checkMethods;
+    }
+
+    static String getCheckerClassName() {
         int javaVersion = Runtime.version().feature();
         final String classNamePrefix;
         if (javaVersion >= 23) {
@@ -47,20 +66,14 @@ public class InstrumenterImpl implements Instrumenter {
         } else {
             classNamePrefix = "";
         }
-        String checkerClass = "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
-        handleClass = checkerClass + "Handle";
-        checkerClassDescriptor = Type.getObjectType(checkerClass).getDescriptor();
+        return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
     }
 
-    /**
-     * To avoid class name collisions during testing without an agent to replace classes in-place.
-     */
-    private final String classNameSuffix;
-    private final Map<MethodKey, CheckerMethod> instrumentationMethods;
-
-    public InstrumenterImpl(String classNameSuffix, Map<MethodKey, CheckerMethod> instrumentationMethods) {
-        this.classNameSuffix = classNameSuffix;
-        this.instrumentationMethods = instrumentationMethods;
+    public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
+        String checkerClass = getCheckerClassName();
+        String handleClass = checkerClass + "Handle";
+        String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
+        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
     }
 
     public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
@@ -156,7 +169,7 @@ public class InstrumenterImpl implements Instrumenter {
                 boolean isStatic = (access & ACC_STATIC) != 0;
                 boolean isCtor = "<init>".equals(name);
                 var key = new MethodKey(className, name, Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList());
-                var instrumentationMethod = instrumentationMethods.get(key);
+                var instrumentationMethod = checkMethods.get(key);
                 if (instrumentationMethod != null) {
                     // LOGGER.debug("Will instrument method {}", key);
                     return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, isCtor, descriptor, instrumentationMethod);
@@ -190,7 +203,7 @@ public class InstrumenterImpl implements Instrumenter {
         private final boolean instrumentedMethodIsStatic;
         private final boolean instrumentedMethodIsCtor;
         private final String instrumentedMethodDescriptor;
-        private final CheckerMethod instrumentationMethod;
+        private final CheckMethod checkMethod;
         private boolean hasCallerSensitiveAnnotation = false;
 
         EntitlementMethodVisitor(
@@ -199,13 +212,13 @@ public class InstrumenterImpl implements Instrumenter {
             boolean instrumentedMethodIsStatic,
             boolean instrumentedMethodIsCtor,
             String instrumentedMethodDescriptor,
-            CheckerMethod instrumentationMethod
+            CheckMethod checkMethod
         ) {
             super(api, methodVisitor);
             this.instrumentedMethodIsStatic = instrumentedMethodIsStatic;
             this.instrumentedMethodIsCtor = instrumentedMethodIsCtor;
             this.instrumentedMethodDescriptor = instrumentedMethodDescriptor;
-            this.instrumentationMethod = instrumentationMethod;
+            this.checkMethod = checkMethod;
         }
 
         @Override
@@ -278,11 +291,11 @@ public class InstrumenterImpl implements Instrumenter {
         private void invokeInstrumentationMethod() {
             mv.visitMethodInsn(
                 INVOKEINTERFACE,
-                instrumentationMethod.className(),
-                instrumentationMethod.methodName(),
+                checkMethod.className(),
+                checkMethod.methodName(),
                 Type.getMethodDescriptor(
                     Type.VOID_TYPE,
-                    instrumentationMethod.parameterDescriptors().stream().map(Type::getType).toArray(Type[]::new)
+                    checkMethod.parameterDescriptors().stream().map(Type::getType).toArray(Type[]::new)
                 ),
                 true
             );
@@ -290,7 +303,7 @@ public class InstrumenterImpl implements Instrumenter {
     }
 
     protected void pushEntitlementChecker(MethodVisitor mv) {
-        mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", "()" + checkerClassDescriptor, false);
+        mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
     }
 
     public record ClassFileInfo(String fileName, byte[] bytecodes) {}

+ 21 - 21
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java

@@ -9,7 +9,7 @@
 
 package org.elasticsearch.entitlement.instrumentation.impl;
 
-import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
 import org.elasticsearch.test.ESTestCase;
@@ -52,15 +52,15 @@ public class InstrumentationServiceImplTests extends ESTestCase {
     }
 
     public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
 
-        assertThat(methodsMap, aMapWithSize(3));
+        assertThat(checkMethods, aMapWithSize(3));
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(new MethodKey("org/example/TestTargetClass", "staticMethod", List.of("I", "java/lang/String", "java/lang/Object"))),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
                         "check$org_example_TestTargetClass$staticMethod",
                         List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;", "Ljava/lang/Object;")
@@ -69,7 +69,7 @@ public class InstrumentationServiceImplTests extends ESTestCase {
             )
         );
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(
                     new MethodKey(
@@ -79,7 +79,7 @@ public class InstrumentationServiceImplTests extends ESTestCase {
                     )
                 ),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
                         "check$$instanceMethodNoArgs",
                         List.of(
@@ -91,7 +91,7 @@ public class InstrumentationServiceImplTests extends ESTestCase {
             )
         );
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(
                     new MethodKey(
@@ -101,7 +101,7 @@ public class InstrumentationServiceImplTests extends ESTestCase {
                     )
                 ),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
                         "check$$instanceMethodWithArgs",
                         List.of(
@@ -117,15 +117,15 @@ public class InstrumentationServiceImplTests extends ESTestCase {
     }
 
     public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
 
-        assertThat(methodsMap, aMapWithSize(2));
+        assertThat(checkMethods, aMapWithSize(2));
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(new MethodKey("org/example/TestTargetClass", "staticMethodWithOverload", List.of("I", "java/lang/String"))),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerOverloads",
                         "check$org_example_TestTargetClass$staticMethodWithOverload",
                         List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;")
@@ -134,11 +134,11 @@ public class InstrumentationServiceImplTests extends ESTestCase {
             )
         );
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(new MethodKey("org/example/TestTargetClass", "staticMethodWithOverload", List.of("I", "I"))),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerOverloads",
                         "check$org_example_TestTargetClass$staticMethodWithOverload",
                         List.of("Ljava/lang/Class;", "I", "I")
@@ -149,15 +149,15 @@ public class InstrumentationServiceImplTests extends ESTestCase {
     }
 
     public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
 
-        assertThat(methodsMap, aMapWithSize(2));
+        assertThat(checkMethods, aMapWithSize(2));
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of("I", "java/lang/String"))),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
                         "check$org_example_TestTargetClass$",
                         List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;")
@@ -166,11 +166,11 @@ public class InstrumentationServiceImplTests extends ESTestCase {
             )
         );
         assertThat(
-            methodsMap,
+            checkMethods,
             hasEntry(
                 equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of())),
                 equalTo(
-                    new CheckerMethod(
+                    new CheckMethod(
                         "org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
                         "check$org_example_TestTargetClass$",
                         List.of("Ljava/lang/Class;")

+ 77 - 301
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java

@@ -9,10 +9,8 @@
 
 package org.elasticsearch.entitlement.instrumentation.impl;
 
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.entitlement.bridge.EntitlementChecker;
-import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
-import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
@@ -23,16 +21,21 @@ import org.objectweb.asm.Type;
 
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.net.MalformedURLException;
+import java.net.URI;
 import java.net.URL;
 import java.net.URLStreamHandlerFactory;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
-import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForConstructor;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
+import static org.hamcrest.Matchers.arrayContaining;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.startsWith;
 import static org.objectweb.asm.Opcodes.INVOKESTATIC;
 
 /**
@@ -42,7 +45,6 @@ import static org.objectweb.asm.Opcodes.INVOKESTATIC;
  */
 @ESTestCase.WithoutSecurityManager
 public class InstrumenterTests extends ESTestCase {
-    final InstrumentationService instrumentationService = new InstrumentationServiceImpl();
 
     static volatile TestEntitlementChecker testChecker;
 
@@ -59,12 +61,7 @@ public class InstrumenterTests extends ESTestCase {
      * Contains all the virtual methods from {@link ClassToInstrument},
      * allowing this test to call them on the dynamically loaded instrumented class.
      */
-    public interface Testable {
-        // This method is here to demonstrate Instrumenter does not get confused by overloads
-        void someMethod(int arg);
-
-        void someMethod(int arg, String anotherArg);
-    }
+    public interface Testable {}
 
     /**
      * This is a placeholder for real class library methods.
@@ -78,41 +75,24 @@ public class InstrumenterTests extends ESTestCase {
 
         public ClassToInstrument() {}
 
-        public ClassToInstrument(int arg) {}
+        // URLClassLoader ctor
+        public ClassToInstrument(URL[] urls) {}
 
         public static void systemExit(int status) {
             assertEquals(123, status);
         }
-
-        public static void anotherSystemExit(int status) {
-            assertEquals(123, status);
-        }
-
-        public void someMethod(int arg) {}
-
-        public void someMethod(int arg, String anotherArg) {}
-
-        public static void someStaticMethod(int arg) {}
-
-        public static void someStaticMethod(int arg, String anotherArg) {}
     }
 
-    static final class TestException extends RuntimeException {}
+    private static final String SAMPLE_NAME = "TEST";
 
-    /**
-     * Interface to test specific, "synthetic" cases (e.g. overloaded methods, overloaded constructors, etc.) that
-     * may be not present/may be difficult to find or not clear in the production EntitlementChecker interface
-     */
-    public interface MockEntitlementChecker extends EntitlementChecker {
-        void checkSomeStaticMethod(Class<?> clazz, int arg);
-
-        void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
-
-        void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
+    private static final URL SAMPLE_URL = createSampleUrl();
 
-        void checkCtor(Class<?> clazz);
-
-        void checkCtor(Class<?> clazz, int arg);
+    private static URL createSampleUrl() {
+        try {
+            return URI.create("file:/test/example").toURL();
+        } catch (MalformedURLException e) {
+            return null;
+        }
     }
 
     /**
@@ -122,7 +102,7 @@ public class InstrumenterTests extends ESTestCase {
      * just to demonstrate that the injected bytecodes succeed in calling these methods.
      * It also asserts that the arguments are correct.
      */
-    public static class TestEntitlementChecker implements MockEntitlementChecker {
+    public static class TestEntitlementChecker implements EntitlementChecker {
         /**
          * This allows us to test that the instrumentation is correct in both cases:
          * if the check throws, and if it doesn't.
@@ -130,104 +110,84 @@ public class InstrumenterTests extends ESTestCase {
         volatile boolean isActive;
 
         int checkSystemExitCallCount = 0;
-        int checkSomeStaticMethodIntCallCount = 0;
-        int checkSomeStaticMethodIntStringCallCount = 0;
-        int checkSomeInstanceMethodCallCount = 0;
-
-        int checkCtorCallCount = 0;
-        int checkCtorIntCallCount = 0;
+        int checkURLClassLoaderCallCount = 0;
 
         @Override
         public void check$java_lang_System$exit(Class<?> callerClass, int status) {
             checkSystemExitCallCount++;
-            assertSame(InstrumenterTests.class, callerClass);
+            assertSame(TestMethodUtils.class, callerClass);
             assertEquals(123, status);
             throwIfActive();
         }
 
         @Override
-        public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls) {}
-
-        @Override
-        public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent) {}
-
-        @Override
-        public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory) {}
-
-        @Override
-        public void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent) {}
-
-        @Override
-        public void check$java_net_URLClassLoader$(
-            Class<?> callerClass,
-            String name,
-            URL[] urls,
-            ClassLoader parent,
-            URLStreamHandlerFactory factory
-        ) {}
-
-        private void throwIfActive() {
-            if (isActive) {
-                throw new TestException();
-            }
-        }
-
-        @Override
-        public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
-            checkSomeStaticMethodIntCallCount++;
+        public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls) {
+            checkURLClassLoaderCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
-            assertEquals(123, arg);
+            assertThat(urls, arrayContaining(SAMPLE_URL));
             throwIfActive();
         }
 
         @Override
-        public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
-            checkSomeStaticMethodIntStringCallCount++;
+        public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent) {
+            checkURLClassLoaderCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
-            assertEquals(123, arg);
-            assertEquals("abc", anotherArg);
+            assertThat(urls, arrayContaining(SAMPLE_URL));
+            assertThat(parent, equalTo(ClassLoader.getSystemClassLoader()));
             throwIfActive();
         }
 
         @Override
-        public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
-            checkSomeInstanceMethodCallCount++;
+        public void check$java_net_URLClassLoader$(Class<?> callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory) {
+            checkURLClassLoaderCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
-            assertThat(
-                that.getClass().getName(),
-                startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$ClassToInstrument")
-            );
-            assertEquals(123, arg);
-            assertEquals("def", anotherArg);
+            assertThat(urls, arrayContaining(SAMPLE_URL));
+            assertThat(parent, equalTo(ClassLoader.getSystemClassLoader()));
             throwIfActive();
         }
 
         @Override
-        public void checkCtor(Class<?> callerClass) {
-            checkCtorCallCount++;
+        public void check$java_net_URLClassLoader$(Class<?> callerClass, String name, URL[] urls, ClassLoader parent) {
+            checkURLClassLoaderCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
+            assertThat(name, equalTo(SAMPLE_NAME));
+            assertThat(urls, arrayContaining(SAMPLE_URL));
+            assertThat(parent, equalTo(ClassLoader.getSystemClassLoader()));
             throwIfActive();
         }
 
         @Override
-        public void checkCtor(Class<?> callerClass, int arg) {
-            checkCtorIntCallCount++;
+        public void check$java_net_URLClassLoader$(
+            Class<?> callerClass,
+            String name,
+            URL[] urls,
+            ClassLoader parent,
+            URLStreamHandlerFactory factory
+        ) {
+            checkURLClassLoaderCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
-            assertEquals(123, arg);
+            assertThat(name, equalTo(SAMPLE_NAME));
+            assertThat(urls, arrayContaining(SAMPLE_URL));
+            assertThat(parent, equalTo(ClassLoader.getSystemClassLoader()));
             throwIfActive();
         }
+
+        private void throwIfActive() {
+            if (isActive) {
+                throw new TestException();
+            }
+        }
     }
 
-    public void testClassIsInstrumented() throws Exception {
+    public void testSystemExitIsInstrumented() throws Exception {
         var classToInstrument = ClassToInstrument.class;
 
-        CheckerMethod checkerMethod = getCheckerMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class);
-        Map<MethodKey, CheckerMethod> methods = Map.of(
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)),
-            checkerMethod
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)),
+            getCheckMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class)
         );
 
-        var instrumenter = createInstrumenter(methods);
+        var instrumenter = createInstrumenter(checkMethods);
 
         byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
 
@@ -251,86 +211,15 @@ public class InstrumenterTests extends ESTestCase {
         assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
     }
 
-    public void testClassIsNotInstrumentedTwice() throws Exception {
-        var classToInstrument = ClassToInstrument.class;
-
-        CheckerMethod checkerMethod = getCheckerMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class);
-        Map<MethodKey, CheckerMethod> methods = Map.of(
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)),
-            checkerMethod
-        );
-
-        var instrumenter = createInstrumenter(methods);
-
-        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
-        var internalClassName = Type.getInternalName(classToInstrument);
-
-        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
-        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
-
-        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
-        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW_NEW",
-            instrumentedTwiceBytecode
-        );
-
-        getTestEntitlementChecker().isActive = true;
-        getTestEntitlementChecker().checkSystemExitCallCount = 0;
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
-        assertEquals(1, getTestEntitlementChecker().checkSystemExitCallCount);
-    }
-
-    public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
+    public void testURLClassLoaderIsInstrumented() throws Exception {
         var classToInstrument = ClassToInstrument.class;
 
-        CheckerMethod checkerMethod = getCheckerMethod(EntitlementChecker.class, "check$java_lang_System$exit", Class.class, int.class);
-        Map<MethodKey, CheckerMethod> methods = Map.of(
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("systemExit", int.class)),
-            checkerMethod,
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("anotherSystemExit", int.class)),
-            checkerMethod
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForConstructor(classToInstrument, List.of(Type.getInternalName(URL[].class))),
+            getCheckMethod(EntitlementChecker.class, "check$java_net_URLClassLoader$", Class.class, URL[].class)
         );
 
-        var instrumenter = createInstrumenter(methods);
-
-        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
-        var internalClassName = Type.getInternalName(classToInstrument);
-
-        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
-        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
-
-        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
-        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW_NEW",
-            instrumentedTwiceBytecode
-        );
-
-        getTestEntitlementChecker().isActive = true;
-        getTestEntitlementChecker().checkSystemExitCallCount = 0;
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "systemExit", 123));
-        assertEquals(1, getTestEntitlementChecker().checkSystemExitCallCount);
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherSystemExit", 123));
-        assertEquals(2, getTestEntitlementChecker().checkSystemExitCallCount);
-    }
-
-    public void testInstrumenterWorksWithOverloads() throws Exception {
-        var classToInstrument = ClassToInstrument.class;
-
-        Map<MethodKey, CheckerMethod> methods = Map.of(
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            getCheckerMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
-            getCheckerMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
-        );
-
-        var instrumenter = createInstrumenter(methods);
+        var instrumenter = createInstrumenter(checkMethods);
 
         byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
 
@@ -343,80 +232,19 @@ public class InstrumenterTests extends ESTestCase {
             newBytecode
         );
 
-        getTestEntitlementChecker().isActive = true;
-
-        // After checking is activated, everything should throw
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
-
-        assertEquals(1, getTestEntitlementChecker().checkSomeStaticMethodIntCallCount);
-        assertEquals(1, getTestEntitlementChecker().checkSomeStaticMethodIntStringCallCount);
-    }
-
-    public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
-        var classToInstrument = ClassToInstrument.class;
-
-        Map<MethodKey, CheckerMethod> methods = Map.of(
-            instrumentationService.methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
-            getCheckerMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
-        );
-
-        var instrumenter = createInstrumenter(methods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
+        getTestEntitlementChecker().isActive = false;
 
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
+        // Before checking is active, nothing should throw
+        newClass.getConstructor(URL[].class).newInstance((Object) new URL[] { SAMPLE_URL });
 
         getTestEntitlementChecker().isActive = true;
 
-        Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
-
-        // This overload is not instrumented, so it will not throw
-        testTargetClass.someMethod(123);
-        assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
-
-        assertEquals(1, getTestEntitlementChecker().checkSomeInstanceMethodCallCount);
-    }
-
-    public void testInstrumenterWorksWithConstructors() throws Exception {
-        var classToInstrument = ClassToInstrument.class;
-
-        Map<MethodKey, CheckerMethod> methods = Map.of(
-            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
-            getCheckerMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
-            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
-            getCheckerMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
-        );
-
-        var instrumenter = createInstrumenter(methods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
+        // After checking is activated, everything should throw
+        var exception = assertThrows(
+            InvocationTargetException.class,
+            () -> newClass.getConstructor(URL[].class).newInstance((Object) new URL[] { SAMPLE_URL })
         );
-
-        getTestEntitlementChecker().isActive = true;
-
-        var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
-        assertThat(ex.getCause(), instanceOf(TestException.class));
-        var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
-        assertThat(ex2.getCause(), instanceOf(TestException.class));
-
-        assertEquals(1, getTestEntitlementChecker().checkCtorCallCount);
-        assertEquals(1, getTestEntitlementChecker().checkCtorIntCallCount);
+        assertThat(exception.getCause(), instanceOf(TestException.class));
     }
 
     /** This test doesn't replace classToInstrument in-place but instead loads a separate
@@ -425,9 +253,10 @@ public class InstrumenterTests extends ESTestCase {
      * MethodKey and instrumentationMethod with slightly different signatures (using the common interface
      * Testable) which is not what would happen when it's run by the agent.
      */
-    private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckerMethod> methods) throws NoSuchMethodException {
+    private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) throws NoSuchMethodException {
         Method getter = InstrumenterTests.class.getMethod("getTestEntitlementChecker");
-        return new InstrumenterImpl("_NEW", methods) {
+
+        return new InstrumenterImpl(null, null, "_NEW", checkMethods) {
             /**
              * We're not testing the bridge library here.
              * Just call our own getter instead.
@@ -445,58 +274,5 @@ public class InstrumenterTests extends ESTestCase {
         };
     }
 
-    private static CheckerMethod getCheckerMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes)
-        throws NoSuchMethodException {
-        var method = clazz.getMethod(methodName, parameterTypes);
-        return new CheckerMethod(
-            Type.getInternalName(clazz),
-            method.getName(),
-            Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
-        );
-    }
-
-    /**
-     * Calling a static method of a dynamically loaded class is significantly more cumbersome
-     * than calling a virtual method.
-     */
-    private static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
-        try {
-            c.getMethod(methodName, int.class).invoke(null, arg);
-        } catch (InvocationTargetException e) {
-            Throwable cause = e.getCause();
-            if (cause instanceof TestException n) {
-                // Sometimes we're expecting this one!
-                throw n;
-            } else {
-                throw new AssertionError(cause);
-            }
-        }
-    }
-
-    private static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
-        IllegalAccessException {
-        try {
-            c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
-        } catch (InvocationTargetException e) {
-            Throwable cause = e.getCause();
-            if (cause instanceof TestException n) {
-                // Sometimes we're expecting this one!
-                throw n;
-            } else {
-                throw new AssertionError(cause);
-            }
-        }
-    }
-
-    static class TestLoader extends ClassLoader {
-        TestLoader(ClassLoader parent) {
-            super(parent);
-        }
-
-        public Class<?> defineClassFromBytes(String name, byte[] bytes) {
-            return defineClass(name, bytes, 0, bytes.length);
-        }
-    }
-
     private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
 }

+ 383 - 0
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/SyntheticInstrumenterTests.java

@@ -0,0 +1,383 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.entitlement.instrumentation.impl;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
+import org.elasticsearch.entitlement.instrumentation.MethodKey;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.test.ESTestCase;
+import org.objectweb.asm.Type;
+
+import java.lang.reflect.InvocationTargetException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
+import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
+import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.startsWith;
+
+/**
+ * This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
+ * some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
+ */
+@ESTestCase.WithoutSecurityManager
+public class SyntheticInstrumenterTests extends ESTestCase {
+    private static final Logger logger = LogManager.getLogger(SyntheticInstrumenterTests.class);
+
+    /**
+     * Contains all the virtual methods from {@link TestClassToInstrument},
+     * allowing this test to call them on the dynamically loaded instrumented class.
+     */
+    public interface Testable {
+        // This method is here to demonstrate Instrumenter does not get confused by overloads
+        void someMethod(int arg);
+
+        void someMethod(int arg, String anotherArg);
+    }
+
+    /**
+     * This is a placeholder for real class library methods.
+     * Without the java agent, we can't instrument the real methods, so we instrument this instead.
+     * <p>
+     * Methods of this class must have the same signature and the same static/virtual condition as the corresponding real method.
+     * They should assert that the arguments came through correctly.
+     * They must not throw {@link TestException}.
+     */
+    public static class TestClassToInstrument implements Testable {
+
+        public TestClassToInstrument() {}
+
+        public TestClassToInstrument(int arg) {}
+
+        public void someMethod(int arg) {}
+
+        public void someMethod(int arg, String anotherArg) {}
+
+        public static void someStaticMethod(int arg) {}
+
+        public static void someStaticMethod(int arg, String anotherArg) {}
+
+        public static void anotherStaticMethod(int arg) {}
+    }
+
+    /**
+     * Interface to test specific, "synthetic" cases (e.g. overloaded methods, overloaded constructors, etc.) that
+     * may be not present/may be difficult to find or not clear in the production EntitlementChecker interface
+     */
+    public interface MockEntitlementChecker {
+        void checkSomeStaticMethod(Class<?> clazz, int arg);
+
+        void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
+
+        void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
+
+        void checkCtor(Class<?> clazz);
+
+        void checkCtor(Class<?> clazz, int arg);
+    }
+
+    public static class TestEntitlementCheckerHolder {
+        static TestEntitlementChecker checkerInstance = new TestEntitlementChecker();
+
+        public static MockEntitlementChecker instance() {
+            return checkerInstance;
+        }
+    }
+
+    public static class TestEntitlementChecker implements MockEntitlementChecker {
+        /**
+         * This allows us to test that the instrumentation is correct in both cases:
+         * if the check throws, and if it doesn't.
+         */
+        volatile boolean isActive;
+
+        int checkSomeStaticMethodIntCallCount = 0;
+        int checkSomeStaticMethodIntStringCallCount = 0;
+        int checkSomeInstanceMethodCallCount = 0;
+
+        int checkCtorCallCount = 0;
+        int checkCtorIntCallCount = 0;
+
+        private void throwIfActive() {
+            if (isActive) {
+                throw new TestException();
+            }
+        }
+
+        @Override
+        public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
+            checkSomeStaticMethodIntCallCount++;
+            assertSame(TestMethodUtils.class, callerClass);
+            assertEquals(123, arg);
+            throwIfActive();
+        }
+
+        @Override
+        public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
+            checkSomeStaticMethodIntStringCallCount++;
+            assertSame(TestMethodUtils.class, callerClass);
+            assertEquals(123, arg);
+            assertEquals("abc", anotherArg);
+            throwIfActive();
+        }
+
+        @Override
+        public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
+            checkSomeInstanceMethodCallCount++;
+            assertSame(SyntheticInstrumenterTests.class, callerClass);
+            assertThat(
+                that.getClass().getName(),
+                startsWith("org.elasticsearch.entitlement.instrumentation.impl.SyntheticInstrumenterTests$TestClassToInstrument")
+            );
+            assertEquals(123, arg);
+            assertEquals("def", anotherArg);
+            throwIfActive();
+        }
+
+        @Override
+        public void checkCtor(Class<?> callerClass) {
+            checkCtorCallCount++;
+            assertSame(SyntheticInstrumenterTests.class, callerClass);
+            throwIfActive();
+        }
+
+        @Override
+        public void checkCtor(Class<?> callerClass, int arg) {
+            checkCtorIntCallCount++;
+            assertSame(SyntheticInstrumenterTests.class, callerClass);
+            assertEquals(123, arg);
+            throwIfActive();
+        }
+    }
+
+    public void testClassIsInstrumented() throws Exception {
+        var classToInstrument = TestClassToInstrument.class;
+
+        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
+            checkMethod
+        );
+
+        var instrumenter = createInstrumenter(checkMethods);
+
+        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
+
+        if (logger.isTraceEnabled()) {
+            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
+        }
+
+        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
+            classToInstrument.getName() + "_NEW",
+            newBytecode
+        );
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
+
+        // Before checking is active, nothing should throw
+        callStaticMethod(newClass, "someStaticMethod", 123);
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+
+        // After checking is activated, everything should throw
+        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+    }
+
+    public void testClassIsNotInstrumentedTwice() throws Exception {
+        var classToInstrument = TestClassToInstrument.class;
+
+        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
+            checkMethod
+        );
+
+        var instrumenter = createInstrumenter(checkMethods);
+
+        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
+        var internalClassName = Type.getInternalName(classToInstrument);
+
+        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
+        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
+
+        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
+        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
+
+        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
+            classToInstrument.getName() + "_NEW_NEW",
+            instrumentedTwiceBytecode
+        );
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
+
+        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
+    }
+
+    public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
+        var classToInstrument = TestClassToInstrument.class;
+
+        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
+            checkMethod,
+            methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
+            checkMethod
+        );
+
+        var instrumenter = createInstrumenter(checkMethods);
+
+        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
+        var internalClassName = Type.getInternalName(classToInstrument);
+
+        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
+        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
+
+        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
+        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
+
+        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
+            classToInstrument.getName() + "_NEW_NEW",
+            instrumentedTwiceBytecode
+        );
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
+
+        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
+
+        assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
+        assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
+    }
+
+    public void testInstrumenterWorksWithOverloads() throws Exception {
+        var classToInstrument = TestClassToInstrument.class;
+
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
+            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
+            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
+            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
+        );
+
+        var instrumenter = createInstrumenter(checkMethods);
+
+        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
+
+        if (logger.isTraceEnabled()) {
+            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
+        }
+
+        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
+            classToInstrument.getName() + "_NEW",
+            newBytecode
+        );
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
+        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
+
+        // After checking is activated, everything should throw
+        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
+
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
+    }
+
+    public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
+        var classToInstrument = TestClassToInstrument.class;
+
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
+            getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
+        );
+
+        var instrumenter = createInstrumenter(checkMethods);
+
+        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
+
+        if (logger.isTraceEnabled()) {
+            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
+        }
+
+        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
+            classToInstrument.getName() + "_NEW",
+            newBytecode
+        );
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
+
+        Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
+
+        // This overload is not instrumented, so it will not throw
+        testTargetClass.someMethod(123);
+        assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
+
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
+    }
+
+    public void testInstrumenterWorksWithConstructors() throws Exception {
+        var classToInstrument = TestClassToInstrument.class;
+
+        Map<MethodKey, CheckMethod> checkMethods = Map.of(
+            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
+            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
+            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
+            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
+        );
+
+        var instrumenter = createInstrumenter(checkMethods);
+
+        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
+
+        if (logger.isTraceEnabled()) {
+            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
+        }
+
+        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
+            classToInstrument.getName() + "_NEW",
+            newBytecode
+        );
+
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+
+        var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
+        assertThat(ex.getCause(), instanceOf(TestException.class));
+        var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
+        assertThat(ex2.getCause(), instanceOf(TestException.class));
+
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
+    }
+
+    /** This test doesn't replace classToInstrument in-place but instead loads a separate
+     * class with the same class name plus a "_NEW" suffix (classToInstrument.class.getName() + "_NEW")
+     * that contains the instrumentation. Because of this, we need to configure the Transformer to use a
+     * MethodKey and instrumentationMethod with slightly different signatures (using the common interface
+     * Testable) which is not what would happen when it's run by the agent.
+     */
+    private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
+        String checkerClass = Type.getInternalName(SyntheticInstrumenterTests.MockEntitlementChecker.class);
+        String handleClass = Type.getInternalName(SyntheticInstrumenterTests.TestEntitlementCheckerHolder.class);
+        String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
+
+        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
+    }
+}

+ 12 - 0
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestException.java

@@ -0,0 +1,12 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.entitlement.instrumentation.impl;
+
+final class TestException extends RuntimeException {}

+ 20 - 0
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java

@@ -0,0 +1,20 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.entitlement.instrumentation.impl;
+
+class TestLoader extends ClassLoader {
+    TestLoader(ClassLoader parent) {
+        super(parent);
+    }
+
+    public Class<?> defineClassFromBytes(String name, byte[] bytes) {
+        return defineClass(name, bytes, 0, bytes.length);
+    }
+}

+ 81 - 0
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.entitlement.instrumentation.impl;
+
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
+import org.elasticsearch.entitlement.instrumentation.MethodKey;
+import org.objectweb.asm.Type;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+
+class TestMethodUtils {
+
+    /**
+     * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
+     */
+    static MethodKey methodKeyForTarget(Method targetMethod) {
+        Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
+        return new MethodKey(
+            Type.getInternalName(targetMethod.getDeclaringClass()),
+            targetMethod.getName(),
+            Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
+        );
+    }
+
+    static MethodKey methodKeyForConstructor(Class<?> classToInstrument, List<String> params) {
+        return new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", params);
+    }
+
+    static CheckMethod getCheckMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) throws NoSuchMethodException {
+        var method = clazz.getMethod(methodName, parameterTypes);
+        return new CheckMethod(
+            Type.getInternalName(clazz),
+            method.getName(),
+            Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
+        );
+    }
+
+    /**
+     * Calling a static method of a dynamically loaded class is significantly more cumbersome
+     * than calling a virtual method.
+     */
+    static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
+        try {
+            c.getMethod(methodName, int.class).invoke(null, arg);
+        } catch (InvocationTargetException e) {
+            Throwable cause = e.getCause();
+            if (cause instanceof TestException n) {
+                // Sometimes we're expecting this one!
+                throw n;
+            } else {
+                throw new AssertionError(cause);
+            }
+        }
+    }
+
+    static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
+        IllegalAccessException {
+        try {
+            c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
+        } catch (InvocationTargetException e) {
+            Throwable cause = e.getCause();
+            if (cause instanceof TestException n) {
+                // Sometimes we're expecting this one!
+                throw n;
+            } else {
+                throw new AssertionError(cause);
+            }
+        }
+    }
+}

+ 4 - 4
libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java

@@ -13,7 +13,7 @@ import org.elasticsearch.core.Tuple;
 import org.elasticsearch.core.internal.provider.ProviderLocator;
 import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap;
 import org.elasticsearch.entitlement.bridge.EntitlementChecker;
-import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
+import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
 import org.elasticsearch.entitlement.instrumentation.Transformer;
@@ -63,13 +63,13 @@ public class EntitlementInitialization {
     public static void initialize(Instrumentation inst) throws Exception {
         manager = initChecker();
 
-        Map<MethodKey, CheckerMethod> methodMap = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
+        Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
             "org.elasticsearch.entitlement.bridge.EntitlementChecker"
         );
 
-        var classesToTransform = methodMap.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
+        var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
 
-        inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter("", methodMap), classesToTransform), true);
+        inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true);
         // TODO: should we limit this array somehow?
         var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new);
         inst.retransformClasses(classesToRetransform);

+ 2 - 2
libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/CheckerMethod.java → libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/CheckMethod.java

@@ -12,7 +12,7 @@ package org.elasticsearch.entitlement.instrumentation;
 import java.util.List;
 
 /**
- * A structure to use as a representation of the checker method the instrumentation will inject.
+ * A structure to use as a representation of the checkXxx method the instrumentation will inject.
  *
  * @param className the "internal name" of the class: includes the package info, but with periods replaced by slashes
  * @param methodName the checker method name
@@ -20,4 +20,4 @@ import java.util.List;
  *                             <a href="https://docs.oracle.com/javase/specs/jvms/se23/html/jvms-4.html#jvms-4.3">type descriptors</a>)
  *                             for methodName parameters.
  */
-public record CheckerMethod(String className, String methodName, List<String> parameterDescriptors) {}
+public record CheckMethod(String className, String methodName, List<String> parameterDescriptors) {}

+ 2 - 8
libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java

@@ -10,19 +10,13 @@
 package org.elasticsearch.entitlement.instrumentation;
 
 import java.io.IOException;
-import java.lang.reflect.Method;
 import java.util.Map;
 
 /**
  * The SPI service entry point for instrumentation.
  */
 public interface InstrumentationService {
-    Instrumenter newInstrumenter(String classNameSuffix, Map<MethodKey, CheckerMethod> instrumentationMethods);
+    Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods);
 
-    /**
-     * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
-     */
-    MethodKey methodKeyForTarget(Method targetMethod);
-
-    Map<MethodKey, CheckerMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException;
+    Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException;
 }