Browse Source

Scripting: Reflect factory signatures in painless classloader (#34088)

It is sometimes desirable to pass a class into a script constructor that
will not actually be exposed in the script whitelist. This commit uses
reflection when creating the compiler to find all the classes of the
factory method signature, and make the classloader that wraps lookup
also expose these classes.
Ryan Ernst 7 years ago
parent
commit
85e4ef3429

+ 40 - 20
modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java

@@ -28,11 +28,14 @@ import org.elasticsearch.painless.spi.Whitelist;
 import org.objectweb.asm.util.Printer;
 
 import java.lang.reflect.Constructor;
+import java.lang.reflect.Method;
 import java.net.MalformedURLException;
 import java.net.URL;
 import java.security.CodeSource;
 import java.security.SecureClassLoader;
 import java.security.cert.Certificate;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -89,16 +92,11 @@ final class Compiler {
          */
         @Override
         public Class<?> findClass(String name) throws ClassNotFoundException {
-            if (scriptClass.getName().equals(name)) {
-                return scriptClass;
+            Class<?> found = additionalClasses.get(name);
+            if (found != null) {
+                return found;
             }
-            if (factoryClass != null && factoryClass.getName().equals(name)) {
-                return factoryClass;
-            }
-            if (statefulFactoryClass != null && statefulFactoryClass.getName().equals(name)) {
-                return statefulFactoryClass;
-            }
-            Class<?> found = painlessLookup.canonicalTypeNameToType(name.replace('$', '.'));
+            found = painlessLookup.canonicalTypeNameToType(name.replace('$', '.'));
 
             return found != null ? found : super.findClass(name);
         }
@@ -156,19 +154,14 @@ final class Compiler {
     private final Class<?> scriptClass;
 
     /**
-     * The class/interface to create the {@code scriptClass} instance.
-     */
-    private final Class<?> factoryClass;
-
-    /**
-     * An optional class/interface to create the {@code factoryClass} instance.
+     * The whitelist the script will use.
      */
-    private final Class<?> statefulFactoryClass;
+    private final PainlessLookup painlessLookup;
 
     /**
-     * The whitelist the script will use.
+     * Classes that do not exist in the lookup, but are needed by the script factories.
      */
-    private final PainlessLookup painlessLookup;
+    private final Map<String, Class<?>> additionalClasses;
 
     /**
      * Standard constructor.
@@ -179,9 +172,36 @@ final class Compiler {
      */
     Compiler(Class<?> scriptClass, Class<?> factoryClass, Class<?> statefulFactoryClass, PainlessLookup painlessLookup) {
         this.scriptClass = scriptClass;
-        this.factoryClass = factoryClass;
-        this.statefulFactoryClass = statefulFactoryClass;
         this.painlessLookup = painlessLookup;
+        Map<String, Class<?>> additionalClasses = new HashMap<>();
+        additionalClasses.put(scriptClass.getName(), scriptClass);
+        addFactoryMethod(additionalClasses, factoryClass, "newInstance");
+        addFactoryMethod(additionalClasses, statefulFactoryClass, "newFactory");
+        addFactoryMethod(additionalClasses, statefulFactoryClass, "newInstance");
+        this.additionalClasses = Collections.unmodifiableMap(additionalClasses);
+    }
+
+    private static void addFactoryMethod(Map<String, Class<?>> additionalClasses, Class<?> factoryClass, String methodName) {
+        if (factoryClass == null) {
+            return;
+        }
+
+        Method factoryMethod = null;
+        for (Method method : factoryClass.getMethods()) {
+            if (methodName.equals(method.getName())) {
+                factoryMethod = method;
+                break;
+            }
+        }
+        if (factoryMethod == null) {
+            return;
+        }
+
+        additionalClasses.put(factoryClass.getName(), factoryClass);
+        for (int i = 0; i < factoryMethod.getParameterTypes().length; ++i) {
+            Class<?> parameterClazz = factoryMethod.getParameterTypes()[i];
+            additionalClasses.put(parameterClazz.getName(), parameterClazz);
+        }
     }
 
     /**