Browse Source

Support script context stateful factory in Painless. (#25233)

Jack Conradson 8 years ago
parent
commit
a4471f51e4

+ 128 - 9
modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java

@@ -43,6 +43,7 @@ import java.security.Permissions;
 import java.security.PrivilegedAction;
 import java.security.ProtectionDomain;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -158,20 +159,126 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
 
             compile(contextsToCompilers.get(context), loader, scriptName, scriptSource, params);
 
-            return generateFactory(loader, context);
+            if (context.statefulFactoryClazz != null) {
+                return generateFactory(loader, context, generateStatefulFactory(loader, context));
+            } else {
+                return generateFactory(loader, context, WriterConstants.CLASS_TYPE);
+            }
+        }
+    }
+
+    /**
+     * Generates a stateful factory class that will return script instances.  Acts as a middle man between
+     * the {@link ScriptContext#factoryClazz} and the {@link ScriptContext#instanceClazz} when used so that
+     * the stateless factory can be used for caching and the stateful factory can act as a cache for new
+     * script instances.  Uses the newInstance method from a {@link ScriptContext#statefulFactoryClazz} to
+     * define the factory method to create new instances of the {@link ScriptContext#instanceClazz}.
+     * @param loader The {@link ClassLoader} that is used to define the factory class and script class.
+     * @param context The {@link ScriptContext}'s semantics are used to define the factory class.
+     * @param <T> The factory class.
+     * @return A factory class that will return script instances.
+     */
+    private <T> Type generateStatefulFactory(Loader loader, ScriptContext<T> context) {
+        int classFrames = ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS;
+        int classAccess = Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER | Opcodes.ACC_FINAL;
+        String interfaceBase = Type.getType(context.statefulFactoryClazz).getInternalName();
+        String className = interfaceBase + "$StatefulFactory";
+        String classInterfaces[] = new String[] { interfaceBase };
+
+        ClassWriter writer = new ClassWriter(classFrames);
+        writer.visit(WriterConstants.CLASS_VERSION, classAccess, className, null, OBJECT_TYPE.getInternalName(), classInterfaces);
+
+        Method newFactory = null;
+
+        for (Method method : context.factoryClazz.getMethods()) {
+            if ("newFactory".equals(method.getName())) {
+                newFactory = method;
+
+                break;
+            }
+        }
+
+        for (int count = 0; count < newFactory.getParameterTypes().length; ++count) {
+            writer.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL, "$arg" + count,
+                Type.getType(newFactory.getParameterTypes()[count]).getDescriptor(), null, null).visitEnd();
+        }
+
+        org.objectweb.asm.commons.Method base =
+            new org.objectweb.asm.commons.Method("<init>", MethodType.methodType(void.class).toMethodDescriptorString());
+        org.objectweb.asm.commons.Method init = new org.objectweb.asm.commons.Method("<init>",
+            MethodType.methodType(void.class, newFactory.getParameterTypes()).toMethodDescriptorString());
+
+        GeneratorAdapter constructor = new GeneratorAdapter(Opcodes.ASM5, init,
+            writer.visitMethod(Opcodes.ACC_PUBLIC, init.getName(), init.getDescriptor(), null, null));
+        constructor.visitCode();
+        constructor.loadThis();
+        constructor.invokeConstructor(OBJECT_TYPE, base);
+
+        for (int count = 0; count < newFactory.getParameterTypes().length; ++count) {
+            constructor.loadThis();
+            constructor.loadArg(count);
+            constructor.putField(Type.getType(className), "$arg" + count, Type.getType(newFactory.getParameterTypes()[count]));
+        }
+
+        constructor.returnValue();
+        constructor.endMethod();
+
+        Method newInstance = null;
+
+        for (Method method : context.statefulFactoryClazz.getMethods()) {
+            if ("newInstance".equals(method.getName())) {
+                newInstance = method;
+
+                break;
+            }
+        }
+
+        org.objectweb.asm.commons.Method instance = new org.objectweb.asm.commons.Method(newInstance.getName(),
+            MethodType.methodType(newInstance.getReturnType(), newInstance.getParameterTypes()).toMethodDescriptorString());
+
+        List<Class<?>> parameters = new ArrayList<>(Arrays.asList(newFactory.getParameterTypes()));
+        parameters.addAll(Arrays.asList(newInstance.getParameterTypes()));
+
+        org.objectweb.asm.commons.Method constru = new org.objectweb.asm.commons.Method("<init>",
+            MethodType.methodType(void.class, parameters.toArray(new Class<?>[] {})).toMethodDescriptorString());
+
+        GeneratorAdapter adapter = new GeneratorAdapter(Opcodes.ASM5, instance,
+            writer.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL,
+                instance.getName(), instance.getDescriptor(), null, null));
+        adapter.visitCode();
+        adapter.newInstance(WriterConstants.CLASS_TYPE);
+        adapter.dup();
+
+        for (int count = 0; count < newFactory.getParameterTypes().length; ++count) {
+            adapter.loadThis();
+            adapter.getField(Type.getType(className), "$arg" + count, Type.getType(newFactory.getParameterTypes()[count]));
         }
+
+        adapter.loadArgs();
+        adapter.invokeConstructor(WriterConstants.CLASS_TYPE, constru);
+        adapter.returnValue();
+        adapter.endMethod();
+
+        writer.visitEnd();
+
+        loader.defineFactory(className.replace('/', '.'), writer.toByteArray());
+
+        return Type.getType(className);
     }
 
     /**
-     * Generates a factory class that will return script instances.
+     * Generates a factory class that will return script instances or stateful factories.
      * Uses the newInstance method from a {@link ScriptContext#factoryClazz} to define the factory method
-     * to create new instances of the {@link ScriptContext#instanceClazz}.
+     * to create new instances of the {@link ScriptContext#instanceClazz} or uses the newFactory method
+     * to create new factories of the {@link ScriptContext#statefulFactoryClazz}.
      * @param loader The {@link ClassLoader} that is used to define the factory class and script class.
      * @param context The {@link ScriptContext}'s semantics are used to define the factory class.
+     * @param classType The type to be instaniated in the newFactory or newInstance method.  Depends
+     *                  on whether a {@link ScriptContext#statefulFactoryClazz} is specified.
      * @param <T> The factory class.
      * @return A factory class that will return script instances.
      */
-    private <T> T generateFactory(Loader loader, ScriptContext<T> context) {
+    private <T> T generateFactory(Loader loader, ScriptContext<T> context, Type classType) {
         int classFrames = ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS;
         int classAccess = Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER| Opcodes.ACC_FINAL;
         String interfaceBase = Type.getType(context.factoryClazz).getInternalName();
@@ -188,25 +295,37 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
                 writer.visitMethod(Opcodes.ACC_PUBLIC, init.getName(), init.getDescriptor(), null, null));
         constructor.visitCode();
         constructor.loadThis();
-        constructor.loadArgs();
         constructor.invokeConstructor(OBJECT_TYPE, init);
         constructor.returnValue();
         constructor.endMethod();
 
-        Method reflect = context.factoryClazz.getMethods()[0];
+        Method reflect = null;
+
+        for (Method method : context.factoryClazz.getMethods()) {
+            if ("newInstance".equals(method.getName())) {
+                reflect = method;
+
+                break;
+            } else if ("newFactory".equals(method.getName())) {
+                reflect = method;
+
+                break;
+            }
+        }
+
         org.objectweb.asm.commons.Method instance = new org.objectweb.asm.commons.Method(reflect.getName(),
             MethodType.methodType(reflect.getReturnType(), reflect.getParameterTypes()).toMethodDescriptorString());
         org.objectweb.asm.commons.Method constru = new org.objectweb.asm.commons.Method("<init>",
             MethodType.methodType(void.class, reflect.getParameterTypes()).toMethodDescriptorString());
 
         GeneratorAdapter adapter = new GeneratorAdapter(Opcodes.ASM5, instance,
-                writer.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER | Opcodes.ACC_FINAL,
+                writer.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL,
                                    instance.getName(), instance.getDescriptor(), null, null));
         adapter.visitCode();
-        adapter.newInstance(WriterConstants.CLASS_TYPE);
+        adapter.newInstance(classType);
         adapter.dup();
         adapter.loadArgs();
-        adapter.invokeConstructor(WriterConstants.CLASS_TYPE, constru);
+        adapter.invokeConstructor(classType, constru);
         adapter.returnValue();
         adapter.endMethod();
 

+ 43 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/FactoryTests.java

@@ -30,6 +30,7 @@ public class FactoryTests extends ScriptTestCase {
 
     protected Collection<ScriptContext<?>> scriptContexts() {
         Collection<ScriptContext<?>> contexts = super.scriptContexts();
+        contexts.add(StatefulFactoryTestScript.CONTEXT);
         contexts.add(FactoryTestScript.CONTEXT);
         contexts.add(EmptyTestScript.CONTEXT);
         contexts.add(TemplateScript.CONTEXT);
@@ -37,6 +38,48 @@ public class FactoryTests extends ScriptTestCase {
         return contexts;
     }
 
+    public abstract static class StatefulFactoryTestScript {
+        private final int x;
+        private final int y;
+
+        public StatefulFactoryTestScript(int x, int y, int a, int b) {
+            this.x = x*a;
+            this.y = y*b;
+        }
+
+        public int getX() {
+            return x;
+        }
+
+        public int getY() {
+            return y*2;
+        }
+
+        public static final String[] PARAMETERS = new String[] {"test"};
+        public abstract Object execute(int test);
+
+        public interface StatefulFactory {
+            StatefulFactoryTestScript newInstance(int a, int b);
+        }
+
+        public interface Factory {
+            StatefulFactory newFactory(int x, int y);
+        }
+
+        public static final ScriptContext<StatefulFactoryTestScript.Factory> CONTEXT =
+            new ScriptContext<>("test", StatefulFactoryTestScript.Factory.class);
+    }
+
+    public void testStatefulFactory() {
+        StatefulFactoryTestScript.Factory factory = scriptEngine.compile(
+            "stateful_factory_test", "test + x + y", StatefulFactoryTestScript.CONTEXT, Collections.emptyMap());
+        StatefulFactoryTestScript.StatefulFactory statefulFactory = factory.newFactory(1, 2);
+        StatefulFactoryTestScript script = statefulFactory.newInstance(3, 4);
+        assertEquals(22, script.execute(3));
+        statefulFactory.newInstance(5, 6);
+        assertEquals(26, script.execute(7));
+    }
+
     public abstract static class FactoryTestScript {
         private final Map<String, Object> params;