Sfoglia il codice sorgente

[8.x] Simplify instrumenter and tests (#118493) (#118714)

This commit simplifies the entitlements instrumentation service and
instrumenter a bit. It especially removes some repetition in the
instrumenter tests.
Ryan Ernst 10 mesi fa
parent
commit
b7dc70ddae

+ 3 - 5
libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImpl.java

@@ -29,15 +29,13 @@ import java.util.Map;
 public class InstrumentationServiceImpl implements InstrumentationService {
 
     @Override
-    public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
-        return InstrumenterImpl.create(checkMethods);
+    public Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods) {
+        return InstrumenterImpl.create(clazz, methods);
     }
 
     @Override
-    public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
-        IOException {
+    public Map<MethodKey, CheckMethod> lookupMethods(Class<?> checkerClass) throws IOException {
         var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
-        var checkerClass = Class.forName(entitlementCheckerClassName);
         var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
         ClassReader reader = new ClassReader(classFileInfo.bytecodes());
         ClassVisitor visitor = new ClassVisitor(Opcodes.ASM9) {

+ 6 - 22
libs/entitlement/asm-provider/src/main/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterImpl.java

@@ -58,30 +58,14 @@ public class InstrumenterImpl implements Instrumenter {
         this.checkMethods = checkMethods;
     }
 
-    static String getCheckerClassName() {
-        int javaVersion = Runtime.version().feature();
-        final String classNamePrefix;
-        if (javaVersion >= 23) {
-            classNamePrefix = "Java23";
-        } else {
-            classNamePrefix = "";
-        }
-        return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
-    }
-
-    public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
-        String checkerClass = getCheckerClassName();
-        String handleClass = checkerClass + "Handle";
-        String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
+    public static InstrumenterImpl create(Class<?> checkerClass, Map<MethodKey, CheckMethod> checkMethods) {
+        Type checkerClassType = Type.getType(checkerClass);
+        String handleClass = checkerClassType.getInternalName() + "Handle";
+        String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(checkerClassType);
         return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
     }
 
-    public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
-        ClassFileInfo initial = getClassFileInfo(clazz);
-        return new ClassFileInfo(initial.fileName(), instrumentClass(Type.getInternalName(clazz), initial.bytecodes()));
-    }
-
-    public static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
+    static ClassFileInfo getClassFileInfo(Class<?> clazz) throws IOException {
         String internalName = Type.getInternalName(clazz);
         String fileName = "/" + internalName + ".class";
         byte[] originalBytecodes;
@@ -306,5 +290,5 @@ public class InstrumenterImpl implements Instrumenter {
         mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
     }
 
-    public record ClassFileInfo(String fileName, byte[] bytecodes) {}
+    record ClassFileInfo(String fileName, byte[] bytecodes) {}
 }

+ 6 - 6
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests.java

@@ -51,8 +51,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
         void check$org_example_TestTargetClass$(Class<?> clazz, int x, String y);
     }
 
-    public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
+    public void testInstrumentationTargetLookup() throws IOException {
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestChecker.class);
 
         assertThat(checkMethods, aMapWithSize(3));
         assertThat(
@@ -116,8 +116,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
         );
     }
 
-    public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
+    public void testInstrumentationTargetLookupWithOverloads() throws IOException {
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerOverloads.class);
 
         assertThat(checkMethods, aMapWithSize(2));
         assertThat(
@@ -148,8 +148,8 @@ public class InstrumentationServiceImplTests extends ESTestCase {
         );
     }
 
-    public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
-        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
+    public void testInstrumentationTargetLookupWithCtors() throws IOException {
+        Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethods(TestCheckerCtors.class);
 
         assertThat(checkMethods, aMapWithSize(2));
         assertThat(

+ 200 - 187
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/InstrumenterTests.java

@@ -12,31 +12,63 @@ package org.elasticsearch.entitlement.instrumentation.impl;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
+import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
 import org.elasticsearch.test.ESTestCase;
+import org.junit.Before;
 import org.objectweb.asm.Type;
 
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Executable;
 import java.lang.reflect.InvocationTargetException;
-import java.util.List;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
+import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Map;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
 import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
-import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
-import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
-import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.startsWith;
+import static org.hamcrest.Matchers.equalTo;
 
 /**
- * This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
- * some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
+ * This tests {@link InstrumenterImpl} can instrument various method signatures
+ * (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
  */
 @ESTestCase.WithoutSecurityManager
 public class InstrumenterTests extends ESTestCase {
     private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
 
+    static class TestLoader extends ClassLoader {
+        final byte[] testClassBytes;
+        final Class<?> testClass;
+
+        TestLoader(String testClassName, byte[] testClassBytes) {
+            super(InstrumenterTests.class.getClassLoader());
+            this.testClassBytes = testClassBytes;
+            this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length);
+        }
+
+        Method getSameMethod(Method method) {
+            try {
+                return testClass.getMethod(method.getName(), method.getParameterTypes());
+            } catch (NoSuchMethodException e) {
+                throw new AssertionError(e);
+            }
+        }
+
+        Constructor<?> getSameConstructor(Constructor<?> ctor) {
+            try {
+                return testClass.getConstructor(ctor.getParameterTypes());
+            } catch (NoSuchMethodException e) {
+                throw new AssertionError(e);
+            }
+        }
+    }
+
     /**
      * Contains all the virtual methods from {@link TestClassToInstrument},
      * allowing this test to call them on the dynamically loaded instrumented class.
@@ -80,13 +112,15 @@ public class InstrumenterTests extends ESTestCase {
     public interface MockEntitlementChecker {
         void checkSomeStaticMethod(Class<?> clazz, int arg);
 
-        void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
+        void checkSomeStaticMethodOverload(Class<?> clazz, int arg, String anotherArg);
+
+        void checkAnotherStaticMethod(Class<?> clazz, int arg);
 
         void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
 
         void checkCtor(Class<?> clazz);
 
-        void checkCtor(Class<?> clazz, int arg);
+        void checkCtorOverload(Class<?> clazz, int arg);
     }
 
     public static class TestEntitlementCheckerHolder {
@@ -105,6 +139,7 @@ public class InstrumenterTests extends ESTestCase {
         volatile boolean isActive;
 
         int checkSomeStaticMethodIntCallCount = 0;
+        int checkAnotherStaticMethodIntCallCount = 0;
         int checkSomeStaticMethodIntStringCallCount = 0;
         int checkSomeInstanceMethodCallCount = 0;
 
@@ -120,28 +155,33 @@ public class InstrumenterTests extends ESTestCase {
         @Override
         public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
             checkSomeStaticMethodIntCallCount++;
-            assertSame(TestMethodUtils.class, callerClass);
+            assertSame(InstrumenterTests.class, callerClass);
             assertEquals(123, arg);
             throwIfActive();
         }
 
         @Override
-        public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
+        public void checkSomeStaticMethodOverload(Class<?> callerClass, int arg, String anotherArg) {
             checkSomeStaticMethodIntStringCallCount++;
-            assertSame(TestMethodUtils.class, callerClass);
+            assertSame(InstrumenterTests.class, callerClass);
             assertEquals(123, arg);
             assertEquals("abc", anotherArg);
             throwIfActive();
         }
 
+        @Override
+        public void checkAnotherStaticMethod(Class<?> callerClass, int arg) {
+            checkAnotherStaticMethodIntCallCount++;
+            assertSame(InstrumenterTests.class, callerClass);
+            assertEquals(123, arg);
+            throwIfActive();
+        }
+
         @Override
         public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
             checkSomeInstanceMethodCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
-            assertThat(
-                that.getClass().getName(),
-                startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument")
-            );
+            assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName()));
             assertEquals(123, arg);
             assertEquals("def", anotherArg);
             throwIfActive();
@@ -155,7 +195,7 @@ public class InstrumenterTests extends ESTestCase {
         }
 
         @Override
-        public void checkCtor(Class<?> callerClass, int arg) {
+        public void checkCtorOverload(Class<?> callerClass, int arg) {
             checkCtorIntCallCount++;
             assertSame(InstrumenterTests.class, callerClass);
             assertEquals(123, arg);
@@ -163,206 +203,83 @@ public class InstrumenterTests extends ESTestCase {
         }
     }
 
-    public void testClassIsInstrumented() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            checkMethod
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
+    @Before
+    public void resetInstance() {
+        TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker();
+    }
 
-        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
+    public void testStaticMethod() throws Exception {
+        Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)));
 
         // Before checking is active, nothing should throw
-        callStaticMethod(newClass, "someStaticMethod", 123);
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-
+        assertStaticMethod(loader, targetMethod, 123);
         // After checking is activated, everything should throw
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertStaticMethodThrows(loader, targetMethod, 123);
     }
 
-    public void testClassIsNotInstrumentedTwice() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            checkMethod
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
-        var internalClassName = Type.getInternalName(classToInstrument);
-
-        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
-        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
+    public void testNotInstrumentedTwice() throws Exception {
+        Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod));
 
-        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
+        var loader1 = instrumentTestClass(instrumenter);
+        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes);
         logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
+        var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode);
 
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW_NEW",
-            instrumentedTwiceBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertStaticMethodThrows(loader2, targetMethod, 123);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
     }
 
-    public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            checkMethod,
-            methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
-            checkMethod
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
-        var internalClassName = Type.getInternalName(classToInstrument);
+    public void testMultipleMethods() throws Exception {
+        Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class);
 
-        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
-        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
+        var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2));
+        var loader = instrumentTestClass(instrumenter);
 
-        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
-        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW_NEW",
-            instrumentedTwiceBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
+        assertStaticMethodThrows(loader, targetMethod1, 123);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
-
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
-        assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
+        assertStaticMethodThrows(loader, targetMethod2, 123);
+        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount);
     }
 
-    public void testInstrumenterWorksWithOverloads() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
-            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
-            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
-            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
+    public void testStaticMethodOverload() throws Exception {
+        Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
+        Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class);
+        var instrumenter = createInstrumenter(
+            Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2)
         );
+        var loader = instrumentTestClass(instrumenter);
 
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
-
-        // After checking is activated, everything should throw
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
-        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
-
+        assertStaticMethodThrows(loader, targetMethod1, 123);
+        assertStaticMethodThrows(loader, targetMethod2, 123, "abc");
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
     }
 
-    public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
-            getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
+    public void testInstanceMethodOverload() throws Exception {
+        Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class);
+        var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod));
+        var loader = instrumentTestClass(instrumenter);
 
         TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-        TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
-
-        Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
+        Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance());
 
         // This overload is not instrumented, so it will not throw
         testTargetClass.someMethod(123);
-        assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
+        expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
 
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
     }
 
-    public void testInstrumenterWorksWithConstructors() throws Exception {
-        var classToInstrument = TestClassToInstrument.class;
-
-        Map<MethodKey, CheckMethod> checkMethods = Map.of(
-            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
-            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
-            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
-            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
-        );
-
-        var instrumenter = createInstrumenter(checkMethods);
-
-        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
-
-        if (logger.isTraceEnabled()) {
-            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
-        }
-
-        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
-            classToInstrument.getName() + "_NEW",
-            newBytecode
-        );
-
-        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
-
-        var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
-        assertThat(ex.getCause(), instanceOf(TestException.class));
-        var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
-        assertThat(ex2.getCause(), instanceOf(TestException.class));
+    public void testConstructors() throws Exception {
+        Constructor<?> ctor1 = TestClassToInstrument.class.getConstructor();
+        Constructor<?> ctor2 = TestClassToInstrument.class.getConstructor(int.class);
+        var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2)));
 
+        assertCtorThrows(loader, ctor1);
+        assertCtorThrows(loader, ctor2, 123);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
         assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
     }
@@ -373,11 +290,107 @@ public class InstrumenterTests extends ESTestCase {
      * MethodKey and instrumentationMethod with slightly different signatures (using the common interface
      * Testable) which is not what would happen when it's run by the agent.
      */
-    private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
+    private static InstrumenterImpl createInstrumenter(Map<String, Executable> methods) throws NoSuchMethodException {
+        Map<MethodKey, CheckMethod> checkMethods = new HashMap<>();
+        for (var entry : methods.entrySet()) {
+            checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue()));
+        }
         String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class);
         String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class);
         String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
 
-        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
+        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
+    }
+
+    private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException {
+        var clazz = TestClassToInstrument.class;
+        ClassFileInfo initial = getClassFileInfo(clazz);
+        byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes());
+        if (logger.isTraceEnabled()) {
+            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
+        }
+        return new TestLoader(clazz.getName(), newBytecode);
+    }
+
+    private static MethodKey getMethodKey(Executable method) {
+        logger.info("method key: {}", method.getName());
+        String methodName = method instanceof Constructor<?> ? "<init>" : method.getName();
+        return new MethodKey(
+            Type.getInternalName(method.getDeclaringClass()),
+            methodName,
+            Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList()
+        );
+    }
+
+    private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException {
+        boolean isStatic = Modifier.isStatic(targetMethod.getModifiers());
+        boolean isInstance = isStatic == false && targetMethod instanceof Method;
+        int extraArgs = 1; // caller class
+        if (isInstance) {
+            ++extraArgs;
+        }
+        Class<?>[] targetParameterTypes = targetMethod.getParameterTypes();
+        Class<?>[] checkParameterTypes = new Class<?>[targetParameterTypes.length + extraArgs];
+        checkParameterTypes[0] = Class.class;
+        if (isInstance) {
+            checkParameterTypes[1] = Testable.class;
+        }
+        System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length);
+        var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes);
+        return new CheckMethod(
+            Type.getInternalName(MockEntitlementChecker.class),
+            checkMethod.getName(),
+            Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList()
+        );
+    }
+
+    private static void unwrapInvocationException(InvocationTargetException e) {
+        Throwable cause = e.getCause();
+        if (cause instanceof TestException n) {
+            // Sometimes we're expecting this one!
+            throw n;
+        } else {
+            throw new AssertionError(cause);
+        }
+    }
+
+    /**
+     * Calling a static method of a dynamically loaded class is significantly more cumbersome
+     * than calling a virtual method.
+     */
+    static void callStaticMethod(Method method, Object... args) {
+        try {
+            method.invoke(null, args);
+        } catch (InvocationTargetException e) {
+            unwrapInvocationException(e);
+        } catch (IllegalAccessException e) {
+            throw new AssertionError(e);
+        }
+    }
+
+    private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) {
+        Method testMethod = loader.getSameMethod(method);
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        expectThrows(TestException.class, () -> callStaticMethod(testMethod, args));
+    }
+
+    private void assertStaticMethod(TestLoader loader, Method method, Object... args) {
+        Method testMethod = loader.getSameMethod(method);
+        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
+        callStaticMethod(testMethod, args);
+    }
+
+    private void assertCtorThrows(TestLoader loader, Constructor<?> ctor, Object... args) {
+        Constructor<?> testCtor = loader.getSameConstructor(ctor);
+        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
+        expectThrows(TestException.class, () -> {
+            try {
+                testCtor.newInstance(args);
+            } catch (InvocationTargetException e) {
+                unwrapInvocationException(e);
+            } catch (IllegalAccessException | InstantiationException e) {
+                throw new AssertionError(e);
+            }
+        });
     }
 }

+ 0 - 20
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestLoader.java

@@ -1,20 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */
-
-package org.elasticsearch.entitlement.instrumentation.impl;
-
-class TestLoader extends ClassLoader {
-    TestLoader(ClassLoader parent) {
-        super(parent);
-    }
-
-    public Class<?> defineClassFromBytes(String name, byte[] bytes) {
-        return defineClass(name, bytes, 0, bytes.length);
-    }
-}

+ 0 - 81
libs/entitlement/asm-provider/src/test/java/org/elasticsearch/entitlement/instrumentation/impl/TestMethodUtils.java

@@ -1,81 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the "Elastic License
- * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
- * Public License v 1"; you may not use this file except in compliance with, at
- * your election, the "Elastic License 2.0", the "GNU Affero General Public
- * License v3.0 only", or the "Server Side Public License, v 1".
- */
-
-package org.elasticsearch.entitlement.instrumentation.impl;
-
-import org.elasticsearch.entitlement.instrumentation.CheckMethod;
-import org.elasticsearch.entitlement.instrumentation.MethodKey;
-import org.objectweb.asm.Type;
-
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.Arrays;
-import java.util.List;
-import java.util.stream.Stream;
-
-class TestMethodUtils {
-
-    /**
-     * @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
-     */
-    static MethodKey methodKeyForTarget(Method targetMethod) {
-        Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
-        return new MethodKey(
-            Type.getInternalName(targetMethod.getDeclaringClass()),
-            targetMethod.getName(),
-            Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
-        );
-    }
-
-    static MethodKey methodKeyForConstructor(Class<?> classToInstrument, List<String> params) {
-        return new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", params);
-    }
-
-    static CheckMethod getCheckMethod(Class<?> clazz, String methodName, Class<?>... parameterTypes) throws NoSuchMethodException {
-        var method = clazz.getMethod(methodName, parameterTypes);
-        return new CheckMethod(
-            Type.getInternalName(clazz),
-            method.getName(),
-            Arrays.stream(Type.getArgumentTypes(method)).map(Type::getDescriptor).toList()
-        );
-    }
-
-    /**
-     * Calling a static method of a dynamically loaded class is significantly more cumbersome
-     * than calling a virtual method.
-     */
-    static void callStaticMethod(Class<?> c, String methodName, int arg) throws NoSuchMethodException, IllegalAccessException {
-        try {
-            c.getMethod(methodName, int.class).invoke(null, arg);
-        } catch (InvocationTargetException e) {
-            Throwable cause = e.getCause();
-            if (cause instanceof TestException n) {
-                // Sometimes we're expecting this one!
-                throw n;
-            } else {
-                throw new AssertionError(cause);
-            }
-        }
-    }
-
-    static void callStaticMethod(Class<?> c, String methodName, int arg1, String arg2) throws NoSuchMethodException,
-        IllegalAccessException {
-        try {
-            c.getMethod(methodName, int.class, String.class).invoke(null, arg1, arg2);
-        } catch (InvocationTargetException e) {
-            Throwable cause = e.getCause();
-            if (cause instanceof TestException n) {
-                // Sometimes we're expecting this one!
-                throw n;
-            } else {
-                throw new AssertionError(cause);
-            }
-        }
-    }
-}

+ 4 - 4
libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java

@@ -14,6 +14,7 @@ import org.elasticsearch.entitlement.bootstrap.EntitlementBootstrap;
 import org.elasticsearch.entitlement.bridge.EntitlementChecker;
 import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
+import org.elasticsearch.entitlement.instrumentation.Instrumenter;
 import org.elasticsearch.entitlement.instrumentation.MethodKey;
 import org.elasticsearch.entitlement.instrumentation.Transformer;
 import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker;
@@ -65,13 +66,12 @@ public class EntitlementInitialization {
     public static void initialize(Instrumentation inst) throws Exception {
         manager = initChecker();
 
-        Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethodsToInstrument(
-            "org.elasticsearch.entitlement.bridge.EntitlementChecker"
-        );
+        Map<MethodKey, CheckMethod> checkMethods = INSTRUMENTER_FACTORY.lookupMethods(EntitlementChecker.class);
 
         var classesToTransform = checkMethods.keySet().stream().map(MethodKey::className).collect(Collectors.toSet());
 
-        inst.addTransformer(new Transformer(INSTRUMENTER_FACTORY.newInstrumenter(checkMethods), classesToTransform), true);
+        Instrumenter instrumenter = INSTRUMENTER_FACTORY.newInstrumenter(EntitlementChecker.class, checkMethods);
+        inst.addTransformer(new Transformer(instrumenter, classesToTransform), true);
         // TODO: should we limit this array somehow?
         var classesToRetransform = classesToTransform.stream().map(EntitlementInitialization::internalNameToClass).toArray(Class[]::new);
         inst.retransformClasses(classesToRetransform);

+ 2 - 2
libs/entitlement/src/main/java/org/elasticsearch/entitlement/instrumentation/InstrumentationService.java

@@ -16,7 +16,7 @@ import java.util.Map;
  * The SPI service entry point for instrumentation.
  */
 public interface InstrumentationService {
-    Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods);
+    Instrumenter newInstrumenter(Class<?> clazz, Map<MethodKey, CheckMethod> methods);
 
-    Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException, IOException;
+    Map<MethodKey, CheckMethod> lookupMethods(Class<?> clazz) throws IOException;
 }