Bladeren bron

Generate Painless Factory for Creating Script Instances (#25120)

Jack Conradson 8 jaren geleden
bovenliggende
commit
d187fa78fd

+ 16 - 4
core/src/main/java/org/elasticsearch/script/TemplateScript.java

@@ -24,14 +24,26 @@ import java.util.Map;
 /**
  * A string template rendered as a script.
  */
-public interface TemplateScript {
+public abstract class TemplateScript {
 
+    private final Map<String, Object> params;
+
+    public TemplateScript(Map<String, Object> params) {
+        this.params = params;
+    }
+
+    /** Return the parameters for this script. */
+    public Map<String, Object> getParams() {
+        return params;
+    }
+
+    public static final String[] PARAMETERS = {};
     /** Run a template and return the resulting string, encoded in utf8 bytes. */
-    String execute();
+    public abstract String execute();
 
-    interface Factory {
+    public interface Factory {
         TemplateScript newInstance(Map<String, Object> params);
     }
 
-    ScriptContext<Factory> CONTEXT = new ScriptContext<>("template", Factory.class);
+    public static final ScriptContext<Factory> CONTEXT = new ScriptContext<>("template", Factory.class);
 }

+ 6 - 1
core/src/test/java/org/elasticsearch/search/suggest/SuggestSearchIT.java

@@ -1035,7 +1035,12 @@ public class SuggestSearchIT extends ESIntegTestCase {
                     script = script.replace("{{" + entry.getKey() + "}}", String.valueOf(entry.getValue()));
                 }
                 String result = script;
-                return () -> result;
+                return new TemplateScript(null) {
+                    @Override
+                    public String execute() {
+                        return result;
+                    }
+                };
             };
             return context.factoryClazz.cast(factory);
         }

+ 2 - 1
modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/MustacheScriptEngine.java

@@ -86,7 +86,7 @@ public final class MustacheScriptEngine implements ScriptEngine {
     /**
      * Used at query execution time by script service in order to execute a query template.
      * */
-    private class MustacheExecutableScript implements TemplateScript {
+    private class MustacheExecutableScript extends TemplateScript {
         /** Factory template. */
         private Mustache template;
 
@@ -96,6 +96,7 @@ public final class MustacheScriptEngine implements ScriptEngine {
          * @param template the compiled template object wrapper
          **/
         MustacheExecutableScript(Mustache template, Map<String, Object> params) {
+            super(params);
             this.template = template;
             this.params = params;
         }

+ 10 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java

@@ -76,6 +76,16 @@ final class Compiler {
             super(parent);
         }
 
+        /**
+         * Generates a Class object from the generated byte code.
+         * @param name The name of the class.
+         * @param bytes The generated byte code.
+         * @return A Class object defining a factory.
+         */
+        Class<?> defineFactory(String name, byte[] bytes) {
+            return defineClass(name, bytes, 0, bytes.length, CODESOURCE);
+        }
+
         /**
          * Generates a Class object from the generated byte code.
          * @param name The name of the class.

+ 146 - 3
modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java

@@ -29,8 +29,14 @@ import org.elasticsearch.script.ScriptContext;
 import org.elasticsearch.script.ScriptEngine;
 import org.elasticsearch.script.ScriptException;
 import org.elasticsearch.script.SearchScript;
+import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.Type;
+import org.objectweb.asm.commons.GeneratorAdapter;
 
+import java.lang.invoke.MethodType;
 import java.lang.reflect.Constructor;
+import java.lang.reflect.Method;
 import java.security.AccessControlContext;
 import java.security.AccessController;
 import java.security.Permissions;
@@ -43,6 +49,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.painless.WriterConstants.OBJECT_TYPE;
+
 /**
  * Implementation of a ScriptEngine for the Painless language.
  */
@@ -115,9 +123,10 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
 
     @Override
     public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
-        GenericElasticsearchScript painlessScript =
-            (GenericElasticsearchScript)compile(contextsToCompilers.get(context), scriptName, scriptSource, params);
         if (context.instanceClazz.equals(SearchScript.class)) {
+            GenericElasticsearchScript painlessScript =
+                (GenericElasticsearchScript)compile(contextsToCompilers.get(context), scriptName, scriptSource, params);
+
             SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                 @Override
                 public SearchScript newInstance(final LeafReaderContext context) {
@@ -130,10 +139,87 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
             };
             return context.factoryClazz.cast(factory);
         } else if (context.instanceClazz.equals(ExecutableScript.class)) {
+            GenericElasticsearchScript painlessScript =
+                (GenericElasticsearchScript)compile(contextsToCompilers.get(context), scriptName, scriptSource, params);
+
             ExecutableScript.Factory factory = (p) -> new ScriptImpl(painlessScript, p, null, null);
             return context.factoryClazz.cast(factory);
+        } else {
+            // Check we ourselves are not being called by unprivileged code.
+            SpecialPermission.check();
+
+            // Create our loader (which loads compiled code with no permissions).
+            final Loader loader = AccessController.doPrivileged(new PrivilegedAction<Loader>() {
+                @Override
+                public Loader run() {
+                    return new Loader(getClass().getClassLoader());
+                }
+            });
+
+            compile(contextsToCompilers.get(context), loader, scriptName, scriptSource, params);
+
+            return generateFactory(loader, context);
+        }
+    }
+
+    /**
+     * Generates a factory class that will return script instances.
+     * Uses the newInstance method from a {@link ScriptContext#factoryClazz} to define the factory method
+     * to create new instances of the {@link ScriptContext#instanceClazz}.
+     * @param loader The {@link ClassLoader} that is used to define the factory class and script class.
+     * @param context The {@link ScriptContext}'s semantics are used to define the factory class.
+     * @param <T> The factory class.
+     * @return A factory class that will return script instances.
+     */
+    private <T> T generateFactory(Loader loader, ScriptContext<T> context) {
+        int classFrames = ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS;
+        int classAccess = Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER| Opcodes.ACC_FINAL;
+        String interfaceBase = Type.getType(context.factoryClazz).getInternalName();
+        String className = interfaceBase + "$Factory";
+        String classInterfaces[] = new String[] { interfaceBase };
+
+        ClassWriter writer = new ClassWriter(classFrames);
+        writer.visit(WriterConstants.CLASS_VERSION, classAccess, className, null, OBJECT_TYPE.getInternalName(), classInterfaces);
+
+        org.objectweb.asm.commons.Method init =
+            new org.objectweb.asm.commons.Method("<init>", MethodType.methodType(void.class).toMethodDescriptorString());
+
+        GeneratorAdapter constructor = new GeneratorAdapter(Opcodes.ASM5, init,
+                writer.visitMethod(Opcodes.ACC_PUBLIC, init.getName(), init.getDescriptor(), null, null));
+        constructor.visitCode();
+        constructor.loadThis();
+        constructor.loadArgs();
+        constructor.invokeConstructor(OBJECT_TYPE, init);
+        constructor.returnValue();
+        constructor.endMethod();
+
+        Method reflect = context.factoryClazz.getMethods()[0];
+        org.objectweb.asm.commons.Method instance = new org.objectweb.asm.commons.Method(reflect.getName(),
+            MethodType.methodType(reflect.getReturnType(), reflect.getParameterTypes()).toMethodDescriptorString());
+        org.objectweb.asm.commons.Method constru = new org.objectweb.asm.commons.Method("<init>",
+            MethodType.methodType(void.class, reflect.getParameterTypes()).toMethodDescriptorString());
+
+        GeneratorAdapter adapter = new GeneratorAdapter(Opcodes.ASM5, instance,
+                writer.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SUPER | Opcodes.ACC_FINAL,
+                                   instance.getName(), instance.getDescriptor(), null, null));
+        adapter.visitCode();
+        adapter.newInstance(WriterConstants.CLASS_TYPE);
+        adapter.dup();
+        adapter.loadArgs();
+        adapter.invokeConstructor(WriterConstants.CLASS_TYPE, constru);
+        adapter.returnValue();
+        adapter.endMethod();
+
+        writer.visitEnd();
+
+        Class<?> factory = loader.defineFactory(className.replace('/', '.'), writer.toByteArray());
+
+        try {
+            return context.factoryClazz.cast(factory.getConstructor().newInstance());
+        } catch (Exception exception) { // Catch everything to let the user know this is something caused internally.
+            throw new IllegalStateException(
+                "An internal error occurred attempting to define the factory class [" + className + "].", exception);
         }
-        throw new IllegalArgumentException("painless does not know how to handle context [" + context.name + "]");
     }
 
     Object compile(Compiler compiler, String scriptName, String source, Map<String, String> params, Object... args) {
@@ -209,6 +295,63 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
         }
     }
 
+    void compile(Compiler compiler, Loader loader, String scriptName, String source, Map<String, String> params) {
+        final CompilerSettings compilerSettings;
+
+        if (params.isEmpty()) {
+            // Use the default settings.
+            compilerSettings = defaultCompilerSettings;
+        } else {
+            // Use custom settings specified by params.
+            compilerSettings = new CompilerSettings();
+
+            // Except regexes enabled - this is a node level setting and can't be changed in the request.
+            compilerSettings.setRegexesEnabled(defaultCompilerSettings.areRegexesEnabled());
+
+            Map<String, String> copy = new HashMap<>(params);
+
+            String value = copy.remove(CompilerSettings.MAX_LOOP_COUNTER);
+            if (value != null) {
+                compilerSettings.setMaxLoopCounter(Integer.parseInt(value));
+            }
+
+            value = copy.remove(CompilerSettings.PICKY);
+            if (value != null) {
+                compilerSettings.setPicky(Boolean.parseBoolean(value));
+            }
+
+            value = copy.remove(CompilerSettings.INITIAL_CALL_SITE_DEPTH);
+            if (value != null) {
+                compilerSettings.setInitialCallSiteDepth(Integer.parseInt(value));
+            }
+
+            value = copy.remove(CompilerSettings.REGEX_ENABLED.getKey());
+            if (value != null) {
+                throw new IllegalArgumentException("[painless.regex.enabled] can only be set on node startup.");
+            }
+
+            if (!copy.isEmpty()) {
+                throw new IllegalArgumentException("Unrecognized compile-time parameter(s): " + copy);
+            }
+        }
+
+        try {
+            // Drop all permissions to actually compile the code itself.
+            AccessController.doPrivileged(new PrivilegedAction<Void>() {
+                @Override
+                public Void run() {
+                    String name = scriptName == null ? INLINE_NAME : scriptName;
+                    compiler.compile(loader, name, source, compilerSettings);
+
+                    return null;
+                }
+            }, COMPILATION_CONTEXT);
+            // Note that it is safe to catch any of the following errors since Painless is stateless.
+        } catch (OutOfMemoryError | StackOverflowError | VerifyError | Exception e) {
+            throw convertToScriptException(scriptName == null ? source : scriptName, source, e);
+        }
+    }
+
     private ScriptException convertToScriptException(String scriptName, String scriptSource, Throwable t) {
         // create a script stack: this is just the script portion
         List<String> scriptStack = new ArrayList<>();

+ 105 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/FactoryTests.java

@@ -0,0 +1,105 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.painless;
+
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.script.TemplateScript;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+
+public class FactoryTests extends ScriptTestCase {
+
+    protected Collection<ScriptContext<?>> scriptContexts() {
+        Collection<ScriptContext<?>> contexts = super.scriptContexts();
+        contexts.add(FactoryTestScript.CONTEXT);
+        contexts.add(EmptyTestScript.CONTEXT);
+        contexts.add(TemplateScript.CONTEXT);
+
+        return contexts;
+    }
+
+    public abstract static class FactoryTestScript {
+        private final Map<String, Object> params;
+
+        public FactoryTestScript(Map<String, Object> params) {
+            this.params = params;
+        }
+
+        public Map<String, Object> getParams() {
+            return params;
+        }
+
+        public static final String[] PARAMETERS = new String[] {"test"};
+        public abstract Object execute(int test);
+
+        public interface Factory {
+            FactoryTestScript newInstance(Map<String, Object> params);
+        }
+
+        public static final ScriptContext<FactoryTestScript.Factory> CONTEXT =
+            new ScriptContext<>("test", FactoryTestScript.Factory.class);
+    }
+
+    public void testFactory() {
+        FactoryTestScript.Factory factory =
+            scriptEngine.compile("factory_test", "test + params.get('test')", FactoryTestScript.CONTEXT, Collections.emptyMap());
+        FactoryTestScript script = factory.newInstance(Collections.singletonMap("test", 2));
+        assertEquals(4, script.execute(2));
+        assertEquals(5, script.execute(3));
+        script = factory.newInstance(Collections.singletonMap("test", 3));
+        assertEquals(5, script.execute(2));
+        assertEquals(2, script.execute(-1));
+    }
+
+    public abstract static class EmptyTestScript {
+        public static final String[] PARAMETERS = {};
+        public abstract Object execute();
+
+        public interface Factory {
+            EmptyTestScript newInstance();
+        }
+
+        public static final ScriptContext<EmptyTestScript.Factory> CONTEXT =
+            new ScriptContext<>("test", EmptyTestScript.Factory.class);
+    }
+
+    public void testEmpty() {
+        EmptyTestScript.Factory factory = scriptEngine.compile("empty_test", "1", EmptyTestScript.CONTEXT, Collections.emptyMap());
+        EmptyTestScript script = factory.newInstance();
+        assertEquals(1, script.execute());
+        assertEquals(1, script.execute());
+        script = factory.newInstance();
+        assertEquals(1, script.execute());
+        assertEquals(1, script.execute());
+    }
+
+    public void testTemplate() {
+        TemplateScript.Factory factory =
+            scriptEngine.compile("template_test", "params['test']", TemplateScript.CONTEXT, Collections.emptyMap());
+        TemplateScript script = factory.newInstance(Collections.singletonMap("test", "abc"));
+        assertEquals("abc", script.execute());
+        assertEquals("abc", script.execute());
+        script = factory.newInstance(Collections.singletonMap("test", "def"));
+        assertEquals("def", script.execute());
+        assertEquals("def", script.execute());
+    }
+}

+ 7 - 2
test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java

@@ -86,8 +86,13 @@ public class MockScriptEngine implements ScriptEngine {
                 // TODO: need a better way to implement all these new contexts
                 // this is just a shim to act as an executable script just as before
                 ExecutableScript execScript = mockCompiled.createExecutableScript(vars);
-                return () -> (String) execScript.run();
-            };
+                    return new TemplateScript(vars) {
+                        @Override
+                        public String execute() {
+                            return (String) execScript.run();
+                        }
+                    };
+                };
             return context.factoryClazz.cast(factory);
         }
         throw new IllegalArgumentException("mock script engine does not know how to handle context [" + context.name + "]");