|  | @@ -12,31 +12,63 @@ package org.elasticsearch.entitlement.instrumentation.impl;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.Strings;
 | 
	
		
			
				|  |  |  import org.elasticsearch.entitlement.instrumentation.CheckMethod;
 | 
	
		
			
				|  |  |  import org.elasticsearch.entitlement.instrumentation.MethodKey;
 | 
	
		
			
				|  |  | +import org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.ClassFileInfo;
 | 
	
		
			
				|  |  |  import org.elasticsearch.logging.LogManager;
 | 
	
		
			
				|  |  |  import org.elasticsearch.logging.Logger;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ESTestCase;
 | 
	
		
			
				|  |  | +import org.junit.Before;
 | 
	
		
			
				|  |  |  import org.objectweb.asm.Type;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import java.io.IOException;
 | 
	
		
			
				|  |  | +import java.lang.reflect.Constructor;
 | 
	
		
			
				|  |  | +import java.lang.reflect.Executable;
 | 
	
		
			
				|  |  |  import java.lang.reflect.InvocationTargetException;
 | 
	
		
			
				|  |  | -import java.util.List;
 | 
	
		
			
				|  |  | +import java.lang.reflect.Method;
 | 
	
		
			
				|  |  | +import java.lang.reflect.Modifier;
 | 
	
		
			
				|  |  | +import java.util.Arrays;
 | 
	
		
			
				|  |  | +import java.util.HashMap;
 | 
	
		
			
				|  |  |  import java.util.Map;
 | 
	
		
			
				|  |  | +import java.util.stream.Stream;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.elasticsearch.entitlement.instrumentation.impl.ASMUtils.bytecode2text;
 | 
	
		
			
				|  |  |  import static org.elasticsearch.entitlement.instrumentation.impl.InstrumenterImpl.getClassFileInfo;
 | 
	
		
			
				|  |  | -import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.callStaticMethod;
 | 
	
		
			
				|  |  | -import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.getCheckMethod;
 | 
	
		
			
				|  |  | -import static org.elasticsearch.entitlement.instrumentation.impl.TestMethodUtils.methodKeyForTarget;
 | 
	
		
			
				|  |  | -import static org.hamcrest.Matchers.instanceOf;
 | 
	
		
			
				|  |  | -import static org.hamcrest.Matchers.startsWith;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.equalTo;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  /**
 | 
	
		
			
				|  |  | - * This tests {@link InstrumenterImpl} with some ad-hoc instrumented method and checker methods, to allow us to check
 | 
	
		
			
				|  |  | - * some ad-hoc test cases (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
 | 
	
		
			
				|  |  | + * This tests {@link InstrumenterImpl} can instrument various method signatures
 | 
	
		
			
				|  |  | + * (e.g. overloaded methods, overloaded targets, multiple instrumentation, etc.)
 | 
	
		
			
				|  |  |   */
 | 
	
		
			
				|  |  |  @ESTestCase.WithoutSecurityManager
 | 
	
		
			
				|  |  |  public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |      private static final Logger logger = LogManager.getLogger(InstrumenterTests.class);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    static class TestLoader extends ClassLoader {
 | 
	
		
			
				|  |  | +        final byte[] testClassBytes;
 | 
	
		
			
				|  |  | +        final Class<?> testClass;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        TestLoader(String testClassName, byte[] testClassBytes) {
 | 
	
		
			
				|  |  | +            super(InstrumenterTests.class.getClassLoader());
 | 
	
		
			
				|  |  | +            this.testClassBytes = testClassBytes;
 | 
	
		
			
				|  |  | +            this.testClass = defineClass(testClassName, testClassBytes, 0, testClassBytes.length);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Method getSameMethod(Method method) {
 | 
	
		
			
				|  |  | +            try {
 | 
	
		
			
				|  |  | +                return testClass.getMethod(method.getName(), method.getParameterTypes());
 | 
	
		
			
				|  |  | +            } catch (NoSuchMethodException e) {
 | 
	
		
			
				|  |  | +                throw new AssertionError(e);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Constructor<?> getSameConstructor(Constructor<?> ctor) {
 | 
	
		
			
				|  |  | +            try {
 | 
	
		
			
				|  |  | +                return testClass.getConstructor(ctor.getParameterTypes());
 | 
	
		
			
				|  |  | +            } catch (NoSuchMethodException e) {
 | 
	
		
			
				|  |  | +                throw new AssertionError(e);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      /**
 | 
	
		
			
				|  |  |       * Contains all the virtual methods from {@link TestClassToInstrument},
 | 
	
		
			
				|  |  |       * allowing this test to call them on the dynamically loaded instrumented class.
 | 
	
	
		
			
				|  | @@ -80,13 +112,15 @@ public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |      public interface MockEntitlementChecker {
 | 
	
		
			
				|  |  |          void checkSomeStaticMethod(Class<?> clazz, int arg);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        void checkSomeStaticMethod(Class<?> clazz, int arg, String anotherArg);
 | 
	
		
			
				|  |  | +        void checkSomeStaticMethodOverload(Class<?> clazz, int arg, String anotherArg);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        void checkAnotherStaticMethod(Class<?> clazz, int arg);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          void checkSomeInstanceMethod(Class<?> clazz, Testable that, int arg, String anotherArg);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          void checkCtor(Class<?> clazz);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        void checkCtor(Class<?> clazz, int arg);
 | 
	
		
			
				|  |  | +        void checkCtorOverload(Class<?> clazz, int arg);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public static class TestEntitlementCheckerHolder {
 | 
	
	
		
			
				|  | @@ -105,6 +139,7 @@ public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |          volatile boolean isActive;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          int checkSomeStaticMethodIntCallCount = 0;
 | 
	
		
			
				|  |  | +        int checkAnotherStaticMethodIntCallCount = 0;
 | 
	
		
			
				|  |  |          int checkSomeStaticMethodIntStringCallCount = 0;
 | 
	
		
			
				|  |  |          int checkSomeInstanceMethodCallCount = 0;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -120,28 +155,33 @@ public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public void checkSomeStaticMethod(Class<?> callerClass, int arg) {
 | 
	
		
			
				|  |  |              checkSomeStaticMethodIntCallCount++;
 | 
	
		
			
				|  |  | -            assertSame(TestMethodUtils.class, callerClass);
 | 
	
		
			
				|  |  | +            assertSame(InstrumenterTests.class, callerClass);
 | 
	
		
			
				|  |  |              assertEquals(123, arg);
 | 
	
		
			
				|  |  |              throwIfActive();
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  | -        public void checkSomeStaticMethod(Class<?> callerClass, int arg, String anotherArg) {
 | 
	
		
			
				|  |  | +        public void checkSomeStaticMethodOverload(Class<?> callerClass, int arg, String anotherArg) {
 | 
	
		
			
				|  |  |              checkSomeStaticMethodIntStringCallCount++;
 | 
	
		
			
				|  |  | -            assertSame(TestMethodUtils.class, callerClass);
 | 
	
		
			
				|  |  | +            assertSame(InstrumenterTests.class, callerClass);
 | 
	
		
			
				|  |  |              assertEquals(123, arg);
 | 
	
		
			
				|  |  |              assertEquals("abc", anotherArg);
 | 
	
		
			
				|  |  |              throwIfActive();
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        @Override
 | 
	
		
			
				|  |  | +        public void checkAnotherStaticMethod(Class<?> callerClass, int arg) {
 | 
	
		
			
				|  |  | +            checkAnotherStaticMethodIntCallCount++;
 | 
	
		
			
				|  |  | +            assertSame(InstrumenterTests.class, callerClass);
 | 
	
		
			
				|  |  | +            assertEquals(123, arg);
 | 
	
		
			
				|  |  | +            throwIfActive();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public void checkSomeInstanceMethod(Class<?> callerClass, Testable that, int arg, String anotherArg) {
 | 
	
		
			
				|  |  |              checkSomeInstanceMethodCallCount++;
 | 
	
		
			
				|  |  |              assertSame(InstrumenterTests.class, callerClass);
 | 
	
		
			
				|  |  | -            assertThat(
 | 
	
		
			
				|  |  | -                that.getClass().getName(),
 | 
	
		
			
				|  |  | -                startsWith("org.elasticsearch.entitlement.instrumentation.impl.InstrumenterTests$TestClassToInstrument")
 | 
	
		
			
				|  |  | -            );
 | 
	
		
			
				|  |  | +            assertThat(that.getClass().getName(), equalTo(TestClassToInstrument.class.getName()));
 | 
	
		
			
				|  |  |              assertEquals(123, arg);
 | 
	
		
			
				|  |  |              assertEquals("def", anotherArg);
 | 
	
		
			
				|  |  |              throwIfActive();
 | 
	
	
		
			
				|  | @@ -155,7 +195,7 @@ public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  | -        public void checkCtor(Class<?> callerClass, int arg) {
 | 
	
		
			
				|  |  | +        public void checkCtorOverload(Class<?> callerClass, int arg) {
 | 
	
		
			
				|  |  |              checkCtorIntCallCount++;
 | 
	
		
			
				|  |  |              assertSame(InstrumenterTests.class, callerClass);
 | 
	
		
			
				|  |  |              assertEquals(123, arg);
 | 
	
	
		
			
				|  | @@ -163,206 +203,83 @@ public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testClassIsInstrumented() throws Exception {
 | 
	
		
			
				|  |  | -        var classToInstrument = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
 | 
	
		
			
				|  |  | -        Map<MethodKey, CheckMethod> checkMethods = Map.of(
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
 | 
	
		
			
				|  |  | -            checkMethod
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var instrumenter = createInstrumenter(checkMethods);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        if (logger.isTraceEnabled()) {
 | 
	
		
			
				|  |  | -            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
 | 
	
		
			
				|  |  | -            classToInstrument.getName() + "_NEW",
 | 
	
		
			
				|  |  | -            newBytecode
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | +    @Before
 | 
	
		
			
				|  |  | +    public void resetInstance() {
 | 
	
		
			
				|  |  | +        TestEntitlementCheckerHolder.checkerInstance = new TestEntitlementChecker();
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
 | 
	
		
			
				|  |  | +    public void testStaticMethod() throws Exception {
 | 
	
		
			
				|  |  | +        Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
 | 
	
		
			
				|  |  | +        TestLoader loader = instrumentTestClass(createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod)));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          // Before checking is active, nothing should throw
 | 
	
		
			
				|  |  | -        callStaticMethod(newClass, "someStaticMethod", 123);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        assertStaticMethod(loader, targetMethod, 123);
 | 
	
		
			
				|  |  |          // After checking is activated, everything should throw
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
 | 
	
		
			
				|  |  | +        assertStaticMethodThrows(loader, targetMethod, 123);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testClassIsNotInstrumentedTwice() throws Exception {
 | 
	
		
			
				|  |  | -        var classToInstrument = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
 | 
	
		
			
				|  |  | -        Map<MethodKey, CheckMethod> checkMethods = Map.of(
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
 | 
	
		
			
				|  |  | -            checkMethod
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var instrumenter = createInstrumenter(checkMethods);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
 | 
	
		
			
				|  |  | -        var internalClassName = Type.getInternalName(classToInstrument);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
 | 
	
		
			
				|  |  | -        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
 | 
	
		
			
				|  |  | +    public void testNotInstrumentedTwice() throws Exception {
 | 
	
		
			
				|  |  | +        Method targetMethod = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
 | 
	
		
			
				|  |  | +        var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
 | 
	
		
			
				|  |  | +        var loader1 = instrumentTestClass(instrumenter);
 | 
	
		
			
				|  |  | +        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(TestClassToInstrument.class.getName(), loader1.testClassBytes);
 | 
	
		
			
				|  |  |          logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
 | 
	
		
			
				|  |  | +        var loader2 = new TestLoader(TestClassToInstrument.class.getName(), instrumentedTwiceBytecode);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
 | 
	
		
			
				|  |  | -            classToInstrument.getName() + "_NEW_NEW",
 | 
	
		
			
				|  |  | -            instrumentedTwiceBytecode
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
 | 
	
		
			
				|  |  | +        assertStaticMethodThrows(loader2, targetMethod, 123);
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testClassAllMethodsAreInstrumentedFirstPass() throws Exception {
 | 
	
		
			
				|  |  | -        var classToInstrument = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        CheckMethod checkMethod = getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class);
 | 
	
		
			
				|  |  | -        Map<MethodKey, CheckMethod> checkMethods = Map.of(
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
 | 
	
		
			
				|  |  | -            checkMethod,
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("anotherStaticMethod", int.class)),
 | 
	
		
			
				|  |  | -            checkMethod
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var instrumenter = createInstrumenter(checkMethods);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        InstrumenterImpl.ClassFileInfo initial = getClassFileInfo(classToInstrument);
 | 
	
		
			
				|  |  | -        var internalClassName = Type.getInternalName(classToInstrument);
 | 
	
		
			
				|  |  | +    public void testMultipleMethods() throws Exception {
 | 
	
		
			
				|  |  | +        Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
 | 
	
		
			
				|  |  | +        Method targetMethod2 = TestClassToInstrument.class.getMethod("anotherStaticMethod", int.class);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        byte[] instrumentedBytecode = instrumenter.instrumentClass(internalClassName, initial.bytecodes());
 | 
	
		
			
				|  |  | -        byte[] instrumentedTwiceBytecode = instrumenter.instrumentClass(internalClassName, instrumentedBytecode);
 | 
	
		
			
				|  |  | +        var instrumenter = createInstrumenter(Map.of("checkSomeStaticMethod", targetMethod1, "checkAnotherStaticMethod", targetMethod2));
 | 
	
		
			
				|  |  | +        var loader = instrumentTestClass(instrumenter);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        logger.trace(() -> Strings.format("Bytecode after 1st instrumentation:\n%s", bytecode2text(instrumentedBytecode)));
 | 
	
		
			
				|  |  | -        logger.trace(() -> Strings.format("Bytecode after 2nd instrumentation:\n%s", bytecode2text(instrumentedTwiceBytecode)));
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
 | 
	
		
			
				|  |  | -            classToInstrument.getName() + "_NEW_NEW",
 | 
	
		
			
				|  |  | -            instrumentedTwiceBytecode
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
 | 
	
		
			
				|  |  | +        assertStaticMethodThrows(loader, targetMethod1, 123);
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> callStaticMethod(newClass, "anotherStaticMethod", 123));
 | 
	
		
			
				|  |  | -        assertEquals(2, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
 | 
	
		
			
				|  |  | +        assertStaticMethodThrows(loader, targetMethod2, 123);
 | 
	
		
			
				|  |  | +        assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkAnotherStaticMethodIntCallCount);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testInstrumenterWorksWithOverloads() throws Exception {
 | 
	
		
			
				|  |  | -        var classToInstrument = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Map<MethodKey, CheckMethod> checkMethods = Map.of(
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class)),
 | 
	
		
			
				|  |  | -            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class),
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("someStaticMethod", int.class, String.class)),
 | 
	
		
			
				|  |  | -            getCheckMethod(MockEntitlementChecker.class, "checkSomeStaticMethod", Class.class, int.class, String.class)
 | 
	
		
			
				|  |  | +    public void testStaticMethodOverload() throws Exception {
 | 
	
		
			
				|  |  | +        Method targetMethod1 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class);
 | 
	
		
			
				|  |  | +        Method targetMethod2 = TestClassToInstrument.class.getMethod("someStaticMethod", int.class, String.class);
 | 
	
		
			
				|  |  | +        var instrumenter = createInstrumenter(
 | 
	
		
			
				|  |  | +            Map.of("checkSomeStaticMethod", targetMethod1, "checkSomeStaticMethodOverload", targetMethod2)
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  | +        var loader = instrumentTestClass(instrumenter);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        var instrumenter = createInstrumenter(checkMethods);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        if (logger.isTraceEnabled()) {
 | 
	
		
			
				|  |  | -            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
 | 
	
		
			
				|  |  | -            classToInstrument.getName() + "_NEW",
 | 
	
		
			
				|  |  | -            newBytecode
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount = 0;
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount = 0;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // After checking is activated, everything should throw
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123));
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> callStaticMethod(newClass, "someStaticMethod", 123, "abc"));
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        assertStaticMethodThrows(loader, targetMethod1, 123);
 | 
	
		
			
				|  |  | +        assertStaticMethodThrows(loader, targetMethod2, 123, "abc");
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntCallCount);
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeStaticMethodIntStringCallCount);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testInstrumenterWorksWithInstanceMethodsAndOverloads() throws Exception {
 | 
	
		
			
				|  |  | -        var classToInstrument = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Map<MethodKey, CheckMethod> checkMethods = Map.of(
 | 
	
		
			
				|  |  | -            methodKeyForTarget(classToInstrument.getMethod("someMethod", int.class, String.class)),
 | 
	
		
			
				|  |  | -            getCheckMethod(MockEntitlementChecker.class, "checkSomeInstanceMethod", Class.class, Testable.class, int.class, String.class)
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var instrumenter = createInstrumenter(checkMethods);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        if (logger.isTraceEnabled()) {
 | 
	
		
			
				|  |  | -            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
 | 
	
		
			
				|  |  | -            classToInstrument.getName() + "_NEW",
 | 
	
		
			
				|  |  | -            newBytecode
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | +    public void testInstanceMethodOverload() throws Exception {
 | 
	
		
			
				|  |  | +        Method targetMethod = TestClassToInstrument.class.getMethod("someMethod", int.class, String.class);
 | 
	
		
			
				|  |  | +        var instrumenter = createInstrumenter(Map.of("checkSomeInstanceMethod", targetMethod));
 | 
	
		
			
				|  |  | +        var loader = instrumentTestClass(instrumenter);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount = 0;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Testable testTargetClass = (Testable) (newClass.getConstructor().newInstance());
 | 
	
		
			
				|  |  | +        Testable testTargetClass = (Testable) (loader.testClass.getConstructor().newInstance());
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          // This overload is not instrumented, so it will not throw
 | 
	
		
			
				|  |  |          testTargetClass.someMethod(123);
 | 
	
		
			
				|  |  | -        assertThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
 | 
	
		
			
				|  |  | +        expectThrows(TestException.class, () -> testTargetClass.someMethod(123, "def"));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkSomeInstanceMethodCallCount);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testInstrumenterWorksWithConstructors() throws Exception {
 | 
	
		
			
				|  |  | -        var classToInstrument = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Map<MethodKey, CheckMethod> checkMethods = Map.of(
 | 
	
		
			
				|  |  | -            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of()),
 | 
	
		
			
				|  |  | -            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class),
 | 
	
		
			
				|  |  | -            new MethodKey(classToInstrument.getName().replace('.', '/'), "<init>", List.of("I")),
 | 
	
		
			
				|  |  | -            getCheckMethod(MockEntitlementChecker.class, "checkCtor", Class.class, int.class)
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var instrumenter = createInstrumenter(checkMethods);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        byte[] newBytecode = instrumenter.instrumentClassFile(classToInstrument).bytecodes();
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        if (logger.isTraceEnabled()) {
 | 
	
		
			
				|  |  | -            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        Class<?> newClass = new TestLoader(Testable.class.getClassLoader()).defineClassFromBytes(
 | 
	
		
			
				|  |  | -            classToInstrument.getName() + "_NEW",
 | 
	
		
			
				|  |  | -            newBytecode
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var ex = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor().newInstance());
 | 
	
		
			
				|  |  | -        assertThat(ex.getCause(), instanceOf(TestException.class));
 | 
	
		
			
				|  |  | -        var ex2 = assertThrows(InvocationTargetException.class, () -> newClass.getConstructor(int.class).newInstance(123));
 | 
	
		
			
				|  |  | -        assertThat(ex2.getCause(), instanceOf(TestException.class));
 | 
	
		
			
				|  |  | +    public void testConstructors() throws Exception {
 | 
	
		
			
				|  |  | +        Constructor<?> ctor1 = TestClassToInstrument.class.getConstructor();
 | 
	
		
			
				|  |  | +        Constructor<?> ctor2 = TestClassToInstrument.class.getConstructor(int.class);
 | 
	
		
			
				|  |  | +        var loader = instrumentTestClass(createInstrumenter(Map.of("checkCtor", ctor1, "checkCtorOverload", ctor2)));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        assertCtorThrows(loader, ctor1);
 | 
	
		
			
				|  |  | +        assertCtorThrows(loader, ctor2, 123);
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorCallCount);
 | 
	
		
			
				|  |  |          assertEquals(1, TestEntitlementCheckerHolder.checkerInstance.checkCtorIntCallCount);
 | 
	
		
			
				|  |  |      }
 | 
	
	
		
			
				|  | @@ -373,11 +290,107 @@ public class InstrumenterTests extends ESTestCase {
 | 
	
		
			
				|  |  |       * MethodKey and instrumentationMethod with slightly different signatures (using the common interface
 | 
	
		
			
				|  |  |       * Testable) which is not what would happen when it's run by the agent.
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  | -    private InstrumenterImpl createInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
 | 
	
		
			
				|  |  | +    private static InstrumenterImpl createInstrumenter(Map<String, Executable> methods) throws NoSuchMethodException {
 | 
	
		
			
				|  |  | +        Map<MethodKey, CheckMethod> checkMethods = new HashMap<>();
 | 
	
		
			
				|  |  | +        for (var entry : methods.entrySet()) {
 | 
	
		
			
				|  |  | +            checkMethods.put(getMethodKey(entry.getValue()), getCheckMethod(entry.getKey(), entry.getValue()));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |          String checkerClass = Type.getInternalName(InstrumenterTests.MockEntitlementChecker.class);
 | 
	
		
			
				|  |  |          String handleClass = Type.getInternalName(InstrumenterTests.TestEntitlementCheckerHolder.class);
 | 
	
		
			
				|  |  |          String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "_NEW", checkMethods);
 | 
	
		
			
				|  |  | +        return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static TestLoader instrumentTestClass(InstrumenterImpl instrumenter) throws IOException {
 | 
	
		
			
				|  |  | +        var clazz = TestClassToInstrument.class;
 | 
	
		
			
				|  |  | +        ClassFileInfo initial = getClassFileInfo(clazz);
 | 
	
		
			
				|  |  | +        byte[] newBytecode = instrumenter.instrumentClass(Type.getInternalName(clazz), initial.bytecodes());
 | 
	
		
			
				|  |  | +        if (logger.isTraceEnabled()) {
 | 
	
		
			
				|  |  | +            logger.trace("Bytecode after instrumentation:\n{}", bytecode2text(newBytecode));
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return new TestLoader(clazz.getName(), newBytecode);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static MethodKey getMethodKey(Executable method) {
 | 
	
		
			
				|  |  | +        logger.info("method key: {}", method.getName());
 | 
	
		
			
				|  |  | +        String methodName = method instanceof Constructor<?> ? "<init>" : method.getName();
 | 
	
		
			
				|  |  | +        return new MethodKey(
 | 
	
		
			
				|  |  | +            Type.getInternalName(method.getDeclaringClass()),
 | 
	
		
			
				|  |  | +            methodName,
 | 
	
		
			
				|  |  | +            Stream.of(method.getParameterTypes()).map(Type::getType).map(Type::getInternalName).toList()
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static CheckMethod getCheckMethod(String methodName, Executable targetMethod) throws NoSuchMethodException {
 | 
	
		
			
				|  |  | +        boolean isStatic = Modifier.isStatic(targetMethod.getModifiers());
 | 
	
		
			
				|  |  | +        boolean isInstance = isStatic == false && targetMethod instanceof Method;
 | 
	
		
			
				|  |  | +        int extraArgs = 1; // caller class
 | 
	
		
			
				|  |  | +        if (isInstance) {
 | 
	
		
			
				|  |  | +            ++extraArgs;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        Class<?>[] targetParameterTypes = targetMethod.getParameterTypes();
 | 
	
		
			
				|  |  | +        Class<?>[] checkParameterTypes = new Class<?>[targetParameterTypes.length + extraArgs];
 | 
	
		
			
				|  |  | +        checkParameterTypes[0] = Class.class;
 | 
	
		
			
				|  |  | +        if (isInstance) {
 | 
	
		
			
				|  |  | +            checkParameterTypes[1] = Testable.class;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        System.arraycopy(targetParameterTypes, 0, checkParameterTypes, extraArgs, targetParameterTypes.length);
 | 
	
		
			
				|  |  | +        var checkMethod = MockEntitlementChecker.class.getMethod(methodName, checkParameterTypes);
 | 
	
		
			
				|  |  | +        return new CheckMethod(
 | 
	
		
			
				|  |  | +            Type.getInternalName(MockEntitlementChecker.class),
 | 
	
		
			
				|  |  | +            checkMethod.getName(),
 | 
	
		
			
				|  |  | +            Arrays.stream(Type.getArgumentTypes(checkMethod)).map(Type::getDescriptor).toList()
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private static void unwrapInvocationException(InvocationTargetException e) {
 | 
	
		
			
				|  |  | +        Throwable cause = e.getCause();
 | 
	
		
			
				|  |  | +        if (cause instanceof TestException n) {
 | 
	
		
			
				|  |  | +            // Sometimes we're expecting this one!
 | 
	
		
			
				|  |  | +            throw n;
 | 
	
		
			
				|  |  | +        } else {
 | 
	
		
			
				|  |  | +            throw new AssertionError(cause);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /**
 | 
	
		
			
				|  |  | +     * Calling a static method of a dynamically loaded class is significantly more cumbersome
 | 
	
		
			
				|  |  | +     * than calling a virtual method.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    static void callStaticMethod(Method method, Object... args) {
 | 
	
		
			
				|  |  | +        try {
 | 
	
		
			
				|  |  | +            method.invoke(null, args);
 | 
	
		
			
				|  |  | +        } catch (InvocationTargetException e) {
 | 
	
		
			
				|  |  | +            unwrapInvocationException(e);
 | 
	
		
			
				|  |  | +        } catch (IllegalAccessException e) {
 | 
	
		
			
				|  |  | +            throw new AssertionError(e);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private void assertStaticMethodThrows(TestLoader loader, Method method, Object... args) {
 | 
	
		
			
				|  |  | +        Method testMethod = loader.getSameMethod(method);
 | 
	
		
			
				|  |  | +        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | +        expectThrows(TestException.class, () -> callStaticMethod(testMethod, args));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private void assertStaticMethod(TestLoader loader, Method method, Object... args) {
 | 
	
		
			
				|  |  | +        Method testMethod = loader.getSameMethod(method);
 | 
	
		
			
				|  |  | +        TestEntitlementCheckerHolder.checkerInstance.isActive = false;
 | 
	
		
			
				|  |  | +        callStaticMethod(testMethod, args);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private void assertCtorThrows(TestLoader loader, Constructor<?> ctor, Object... args) {
 | 
	
		
			
				|  |  | +        Constructor<?> testCtor = loader.getSameConstructor(ctor);
 | 
	
		
			
				|  |  | +        TestEntitlementCheckerHolder.checkerInstance.isActive = true;
 | 
	
		
			
				|  |  | +        expectThrows(TestException.class, () -> {
 | 
	
		
			
				|  |  | +            try {
 | 
	
		
			
				|  |  | +                testCtor.newInstance(args);
 | 
	
		
			
				|  |  | +            } catch (InvocationTargetException e) {
 | 
	
		
			
				|  |  | +                unwrapInvocationException(e);
 | 
	
		
			
				|  |  | +            } catch (IllegalAccessException | InstantiationException e) {
 | 
	
		
			
				|  |  | +                throw new AssertionError(e);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  }
 |