Browse Source

Make Painless Compiler Use an Instance Per Context (#24972)

Allows for easier management of compilation of individual interfaces on a per script context basis.
Jack Conradson 8 years ago
parent
commit
04daac2243

+ 27 - 19
modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java

@@ -24,6 +24,7 @@ import org.elasticsearch.painless.antlr.Walker;
 import org.elasticsearch.painless.node.SSource;
 import org.objectweb.asm.util.Printer;
 
+import java.lang.reflect.Constructor;
 import java.net.MalformedURLException;
 import java.net.URL;
 import java.security.CodeSource;
@@ -104,28 +105,44 @@ final class Compiler {
         }
     }
 
+    /**
+     * The class/interface the script is guaranteed to derive/implement.
+     */
+    private final Class<?> base;
+
+    /**
+     * The whitelist the script will use.
+     */
+    private final Definition definition;
+
+    /**
+     * Standard constructor.
+     * @param base The class/interface the script is guaranteed to derive/implement.
+     * @param definition The whitelist the script will use.
+     */
+    Compiler(Class<?> base, Definition definition) {
+        this.base = base;
+        this.definition = definition;
+    }
+
     /**
      * Runs the two-pass compiler to generate a Painless script.
-     * @param <T> the type of the script
      * @param loader The ClassLoader used to define the script.
-     * @param iface Interface the compiled script should implement
      * @param name The name of the script.
      * @param source The source code for the script.
      * @param settings The CompilerSettings to be used during the compilation.
-     * @return An executable script that implements both {@code <T>} and is a subclass of {@link PainlessScript}
+     * @return An executable script that implements both a specified interface and is a subclass of {@link PainlessScript}
      */
-    static <T> T compile(Loader loader, Class<T> iface, String name, String source, CompilerSettings settings) {
+    Constructor<? extends PainlessScript> compile(Loader loader, String name, String source, CompilerSettings settings) {
         if (source.length() > MAXIMUM_SOURCE_LENGTH) {
             throw new IllegalArgumentException("Scripts may be no longer than " + MAXIMUM_SOURCE_LENGTH +
                 " characters.  The passed in script is " + source.length() + " characters.  Consider using a" +
                 " plugin if a script longer than this length is a requirement.");
         }
-        Definition definition = Definition.BUILTINS;
-        ScriptInterface scriptInterface = new ScriptInterface(definition, iface);
 
+        ScriptInterface scriptInterface = new ScriptInterface(definition, base);
         SSource root = Walker.buildPainlessTree(scriptInterface, name, source, settings, definition,
                 null);
-
         root.analyze(definition);
         root.write();
 
@@ -135,9 +152,8 @@ final class Compiler {
             clazz.getField("$SOURCE").set(null, source);
             clazz.getField("$STATEMENTS").set(null, root.getStatements());
             clazz.getField("$DEFINITION").set(null, definition);
-            java.lang.reflect.Constructor<? extends PainlessScript> constructor = clazz.getConstructor();
 
-            return iface.cast(constructor.newInstance());
+            return clazz.getConstructor();
         } catch (Exception exception) { // Catch everything to let the user know this is something caused internally.
             throw new IllegalStateException("An internal error occurred attempting to define the script [" + name + "].", exception);
         }
@@ -145,31 +161,23 @@ final class Compiler {
 
     /**
      * Runs the two-pass compiler to generate a Painless script.  (Used by the debugger.)
-     * @param iface Interface the compiled script should implement
      * @param source The source code for the script.
      * @param settings The CompilerSettings to be used during the compilation.
      * @return The bytes for compilation.
      */
-    static byte[] compile(Class<?> iface, String name, String source, CompilerSettings settings, Printer debugStream) {
+    byte[] compile(String name, String source, CompilerSettings settings, Printer debugStream) {
         if (source.length() > MAXIMUM_SOURCE_LENGTH) {
             throw new IllegalArgumentException("Scripts may be no longer than " + MAXIMUM_SOURCE_LENGTH +
                 " characters.  The passed in script is " + source.length() + " characters.  Consider using a" +
                 " plugin if a script longer than this length is a requirement.");
         }
-        Definition definition = Definition.BUILTINS;
-        ScriptInterface scriptInterface = new ScriptInterface(definition, iface);
 
+        ScriptInterface scriptInterface = new ScriptInterface(definition, base);
         SSource root = Walker.buildPainlessTree(scriptInterface, name, source, settings, definition,
                 debugStream);
-
         root.analyze(definition);
         root.write();
 
         return root.getBytes();
     }
-
-    /**
-     * All methods in the compiler should be static.
-     */
-    private Compiler() {}
 }

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessPlugin.java

@@ -43,7 +43,7 @@ public final class PainlessPlugin extends Plugin implements ScriptPlugin {
 
     @Override
     public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
-        return new PainlessScriptEngine(settings);
+        return new PainlessScriptEngine(settings, contexts);
     }
 
     @Override

+ 33 - 7
modules/lang-painless/src/main/java/org/elasticsearch/painless/PainlessScriptEngine.java

@@ -32,12 +32,15 @@ import org.elasticsearch.script.ScriptException;
 import org.elasticsearch.script.SearchScript;
 
 import java.io.IOException;
+import java.lang.reflect.Constructor;
 import java.security.AccessControlContext;
 import java.security.AccessController;
 import java.security.Permissions;
 import java.security.PrivilegedAction;
 import java.security.ProtectionDomain;
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -70,17 +73,32 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
 
     /**
      * Default compiler settings to be used. Note that {@link CompilerSettings} is mutable but this instance shouldn't be mutated outside
-     * of {@link PainlessScriptEngine#PainlessScriptEngine(Settings)}.
+     * of {@link PainlessScriptEngine#PainlessScriptEngine(Settings, Collection)}.
      */
     private final CompilerSettings defaultCompilerSettings = new CompilerSettings();
 
+    private final Map<ScriptContext<?>, Compiler> contextsToCompilers;
+
     /**
      * Constructor.
      * @param settings The settings to initialize the engine with.
      */
-    public PainlessScriptEngine(final Settings settings) {
+    public PainlessScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
         super(settings);
+
         defaultCompilerSettings.setRegexesEnabled(CompilerSettings.REGEX_ENABLED.get(settings));
+
+        Map<ScriptContext<?>, Compiler> contextsToCompilers = new HashMap<>();
+
+        for (ScriptContext<?> context : contexts) {
+            if (context.instanceClazz.equals(SearchScript.class) || context.instanceClazz.equals(ExecutableScript.class)) {
+                contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class, Definition.BUILTINS));
+            } else {
+                contextsToCompilers.put(context, new Compiler(context.instanceClazz, Definition.BUILTINS));
+            }
+        }
+
+        this.contextsToCompilers = Collections.unmodifiableMap(contextsToCompilers);
     }
 
     /**
@@ -99,7 +117,8 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
 
     @Override
     public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
-        GenericElasticsearchScript painlessScript = compile(GenericElasticsearchScript.class, scriptName, scriptSource, params);
+        GenericElasticsearchScript painlessScript =
+            (GenericElasticsearchScript)compile(contextsToCompilers.get(context), scriptName, scriptSource, params);
         if (context.instanceClazz.equals(SearchScript.class)) {
             SearchScript.Factory factory = (p, lookup) -> new SearchScript() {
                 @Override
@@ -119,7 +138,7 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
         throw new IllegalArgumentException("painless does not know how to handle context [" + context.name + "]");
     }
 
-    <T> T compile(Class<T> iface, String scriptName, final String scriptSource, final Map<String, String> params) {
+    PainlessScript compile(Compiler compiler, String scriptName, final String scriptSource, final Map<String, String> params) {
         final CompilerSettings compilerSettings;
 
         if (params.isEmpty()) {
@@ -172,11 +191,18 @@ public final class PainlessScriptEngine extends AbstractComponent implements Scr
 
         try {
             // Drop all permissions to actually compile the code itself.
-            return AccessController.doPrivileged(new PrivilegedAction<T>() {
+            return AccessController.doPrivileged(new PrivilegedAction<PainlessScript>() {
                 @Override
-                public T run() {
+                public PainlessScript run() {
                     String name = scriptName == null ? INLINE_NAME : scriptName;
-                    return Compiler.compile(loader, iface, name, scriptSource, compilerSettings);
+                    Constructor<? extends PainlessScript> constructor = compiler.compile(loader, name, scriptSource, compilerSettings);
+
+                    try {
+                        return constructor.newInstance();
+                    } catch (Exception exception) { // Catch everything to let the user know this is something caused internally.
+                        throw new IllegalStateException(
+                            "An internal error occurred attempting to define the script [" + name + "].", exception);
+                    }
                 }
             }, COMPILATION_CONTEXT);
         // Note that it is safe to catch any of the following errors since Painless is stateless.

+ 2 - 2
modules/lang-painless/src/test/java/org/elasticsearch/painless/Debugger.java

@@ -39,14 +39,14 @@ final class Debugger {
         PrintWriter outputWriter = new PrintWriter(output);
         Textifier textifier = new Textifier();
         try {
-            Compiler.compile(iface, "<debugging>", source, settings, textifier);
+            new Compiler(iface, Definition.BUILTINS).compile("<debugging>", source, settings, textifier);
         } catch (Exception e) {
             textifier.print(outputWriter);
             e.addSuppressed(new Exception("current bytecode: \n" + output));
             IOUtils.reThrowUnchecked(e);
             throw new AssertionError();
         }
-        
+
         textifier.print(outputWriter);
         return output.toString();
     }

+ 124 - 89
modules/lang-painless/src/test/java/org/elasticsearch/painless/ImplementInterfacesTests.java

@@ -19,6 +19,9 @@
 
 package org.elasticsearch.painless;
 
+import org.elasticsearch.script.ScriptContext;
+
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -37,15 +40,16 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute();
     }
     public void testNoArgs() {
-        assertEquals(1, scriptEngine.compile(NoArgs.class, null, "1", emptyMap()).execute());
-        assertEquals("foo", scriptEngine.compile(NoArgs.class, null, "'foo'", emptyMap()).execute());
+        Compiler compiler = new Compiler(NoArgs.class, Definition.BUILTINS);
+        assertEquals(1, ((NoArgs)scriptEngine.compile(compiler, null, "1", emptyMap())).execute());
+        assertEquals("foo", ((NoArgs)scriptEngine.compile(compiler, null, "'foo'", emptyMap())).execute());
 
         Exception e = expectScriptThrows(IllegalArgumentException.class, () ->
-            scriptEngine.compile(NoArgs.class, null, "doc", emptyMap()));
+                scriptEngine.compile(compiler, null, "doc", emptyMap()));
         assertEquals("Variable [doc] is not defined.", e.getMessage());
         // _score was once embedded into painless by deep magic
         e = expectScriptThrows(IllegalArgumentException.class, () ->
-            scriptEngine.compile(NoArgs.class, null, "_score", emptyMap()));
+                scriptEngine.compile(compiler, null, "_score", emptyMap()));
         assertEquals("Variable [_score] is not defined.", e.getMessage());
 
         String debug = Debugger.toString(NoArgs.class, "int i = 0", new CompilerSettings());
@@ -60,17 +64,19 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(Object arg);
     }
     public void testOneArg() {
+        Compiler compiler = new Compiler(OneArg.class, Definition.BUILTINS);
         Object rando = randomInt();
-        assertEquals(rando, scriptEngine.compile(OneArg.class, null, "arg", emptyMap()).execute(rando));
+        assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando));
         rando = randomAlphaOfLength(5);
-        assertEquals(rando, scriptEngine.compile(OneArg.class, null, "arg", emptyMap()).execute(rando));
+        assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando));
 
+        Compiler noargs = new Compiler(NoArgs.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, () ->
-            scriptEngine.compile(NoArgs.class, null, "doc", emptyMap()));
+                scriptEngine.compile(noargs, null, "doc", emptyMap()));
         assertEquals("Variable [doc] is not defined.", e.getMessage());
         // _score was once embedded into painless by deep magic
         e = expectScriptThrows(IllegalArgumentException.class, () ->
-            scriptEngine.compile(NoArgs.class, null, "_score", emptyMap()));
+                scriptEngine.compile(noargs, null, "_score", emptyMap()));
         assertEquals("Variable [_score] is not defined.", e.getMessage());
     }
 
@@ -79,8 +85,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(String[] arg);
     }
     public void testArrayArg() {
+        Compiler compiler = new Compiler(ArrayArg.class, Definition.BUILTINS);
         String rando = randomAlphaOfLength(5);
-        assertEquals(rando, scriptEngine.compile(ArrayArg.class, null, "arg[0]", emptyMap()).execute(new String[] {rando, "foo"}));
+        assertEquals(rando, ((ArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new String[] {rando, "foo"}));
     }
 
     public interface PrimitiveArrayArg {
@@ -88,8 +95,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(int[] arg);
     }
     public void testPrimitiveArrayArg() {
+        Compiler compiler = new Compiler(PrimitiveArrayArg.class, Definition.BUILTINS);
         int rando = randomInt();
-        assertEquals(rando, scriptEngine.compile(PrimitiveArrayArg.class, null, "arg[0]", emptyMap()).execute(new int[] {rando, 10}));
+        assertEquals(rando, ((PrimitiveArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new int[] {rando, 10}));
     }
 
     public interface DefArrayArg {
@@ -97,11 +105,13 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(Object[] arg);
     }
     public void testDefArrayArg() {
+        Compiler compiler = new Compiler(DefArrayArg.class, Definition.BUILTINS);
         Object rando = randomInt();
-        assertEquals(rando, scriptEngine.compile(DefArrayArg.class, null, "arg[0]", emptyMap()).execute(new Object[] {rando, 10}));
+        assertEquals(rando, ((DefArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new Object[] {rando, 10}));
         rando = randomAlphaOfLength(5);
-        assertEquals(rando, scriptEngine.compile(DefArrayArg.class, null, "arg[0]", emptyMap()).execute(new Object[] {rando, 10}));
-        assertEquals(5, scriptEngine.compile(DefArrayArg.class, null, "arg[0].length()", emptyMap()).execute(new Object[] {rando, 10}));
+        assertEquals(rando, ((DefArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new Object[] {rando, 10}));
+        assertEquals(5,
+            ((DefArrayArg)scriptEngine.compile(compiler, null, "arg[0].length()", emptyMap())).execute(new Object[] {rando, 10}));
     }
 
     public interface ManyArgs {
@@ -113,22 +123,23 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         boolean uses$d();
     }
     public void testManyArgs() {
+        Compiler compiler = new Compiler(ManyArgs.class, Definition.BUILTINS);
         int rando = randomInt();
-        assertEquals(rando, scriptEngine.compile(ManyArgs.class, null, "a", emptyMap()).execute(rando, 0, 0, 0));
-        assertEquals(10, scriptEngine.compile(ManyArgs.class, null, "a + b + c + d", emptyMap()).execute(1, 2, 3, 4));
+        assertEquals(rando, ((ManyArgs)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0));
+        assertEquals(10, ((ManyArgs)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).execute(1, 2, 3, 4));
 
         // While we're here we can verify that painless correctly finds used variables
-        ManyArgs script = scriptEngine.compile(ManyArgs.class, null, "a", emptyMap());
+        ManyArgs script = (ManyArgs)scriptEngine.compile(compiler, null, "a", emptyMap());
         assertTrue(script.uses$a());
         assertFalse(script.uses$b());
         assertFalse(script.uses$c());
         assertFalse(script.uses$d());
-        script = scriptEngine.compile(ManyArgs.class, null, "a + b + c", emptyMap());
+        script = (ManyArgs)scriptEngine.compile(compiler, null, "a + b + c", emptyMap());
         assertTrue(script.uses$a());
         assertTrue(script.uses$b());
         assertTrue(script.uses$c());
         assertFalse(script.uses$d());
-        script = scriptEngine.compile(ManyArgs.class, null, "a + b + c + d", emptyMap());
+        script = (ManyArgs)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap());
         assertTrue(script.uses$a());
         assertTrue(script.uses$b());
         assertTrue(script.uses$c());
@@ -140,7 +151,8 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(String... arg);
     }
     public void testVararg() {
-        assertEquals("foo bar baz", scriptEngine.compile(VarargTest.class, null, "String.join(' ', Arrays.asList(arg))", emptyMap())
+        Compiler compiler = new Compiler(VarargTest.class, Definition.BUILTINS);
+        assertEquals("foo bar baz", ((VarargTest)scriptEngine.compile(compiler, null, "String.join(' ', Arrays.asList(arg))", emptyMap()))
                     .execute("foo", "bar", "baz"));
     }
 
@@ -155,12 +167,13 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         }
     }
     public void testDefaultMethods() {
+        Compiler compiler = new Compiler(DefaultMethods.class, Definition.BUILTINS);
         int rando = randomInt();
-        assertEquals(rando, scriptEngine.compile(DefaultMethods.class, null, "a", emptyMap()).execute(rando, 0, 0, 0));
-        assertEquals(rando, scriptEngine.compile(DefaultMethods.class, null, "a", emptyMap()).executeWithASingleOne(rando, 0, 0));
-        assertEquals(10, scriptEngine.compile(DefaultMethods.class, null, "a + b + c + d", emptyMap()).execute(1, 2, 3, 4));
-        assertEquals(4, scriptEngine.compile(DefaultMethods.class, null, "a + b + c + d", emptyMap()).executeWithOne());
-        assertEquals(7, scriptEngine.compile(DefaultMethods.class, null, "a + b + c + d", emptyMap()).executeWithASingleOne(1, 2, 3));
+        assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0));
+        assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).executeWithASingleOne(rando, 0, 0));
+        assertEquals(10, ((DefaultMethods)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).execute(1, 2, 3, 4));
+        assertEquals(4, ((DefaultMethods)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).executeWithOne());
+        assertEquals(7, ((DefaultMethods)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).executeWithASingleOne(1, 2, 3));
     }
 
     public interface ReturnsVoid {
@@ -168,10 +181,11 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         void execute(Map<String, Object> map);
     }
     public void testReturnsVoid() {
+        Compiler compiler = new Compiler(ReturnsVoid.class, Definition.BUILTINS);
         Map<String, Object> map = new HashMap<>();
-        scriptEngine.compile(ReturnsVoid.class, null, "map.a = 'foo'", emptyMap()).execute(map);
+        ((ReturnsVoid)scriptEngine.compile(compiler, null, "map.a = 'foo'", emptyMap())).execute(map);
         assertEquals(singletonMap("a", "foo"), map);
-        scriptEngine.compile(ReturnsVoid.class, null, "map.remove('a')", emptyMap()).execute(map);
+        ((ReturnsVoid)scriptEngine.compile(compiler, null, "map.remove('a')", emptyMap())).execute(map);
         assertEquals(emptyMap(), map);
 
         String debug = Debugger.toString(ReturnsVoid.class, "int i = 0", new CompilerSettings());
@@ -186,15 +200,18 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         boolean execute();
     }
     public void testReturnsPrimitiveBoolean() {
-        assertEquals(true, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "true", emptyMap()).execute());
-        assertEquals(false, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "false", emptyMap()).execute());
-        assertEquals(true, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "Boolean.TRUE", emptyMap()).execute());
-        assertEquals(false, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "Boolean.FALSE", emptyMap()).execute());
+        Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, Definition.BUILTINS);
+
+        assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "true", emptyMap())).execute());
+        assertEquals(false, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "false", emptyMap())).execute());
+        assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "Boolean.TRUE", emptyMap())).execute());
+        assertEquals(false, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "Boolean.FALSE", emptyMap())).execute());
 
-        assertEquals(true, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "def i = true; i", emptyMap()).execute());
-        assertEquals(true, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "def i = Boolean.TRUE; i", emptyMap()).execute());
+        assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "def i = true; i", emptyMap())).execute());
+        assertEquals(true,
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "def i = Boolean.TRUE; i", emptyMap())).execute());
 
-        assertEquals(true, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "true || false", emptyMap()).execute());
+        assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "true || false", emptyMap())).execute());
 
         String debug = Debugger.toString(ReturnsPrimitiveBoolean.class, "false", new CompilerSettings());
         assertThat(debug, containsString("ICONST_0"));
@@ -202,22 +219,22 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         assertThat(debug, containsString("IRETURN"));
 
         Exception e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "1L", emptyMap()).execute());
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "1L", emptyMap())).execute());
         assertEquals("Cannot cast from [long] to [boolean].", e.getMessage());
         e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "1.1f", emptyMap()).execute());
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute());
         assertEquals("Cannot cast from [float] to [boolean].", e.getMessage());
         e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "1.1d", emptyMap()).execute());
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "1.1d", emptyMap())).execute());
         assertEquals("Cannot cast from [double] to [boolean].", e.getMessage());
         expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "def i = 1L; i", emptyMap()).execute());
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "def i = 1L; i", emptyMap())).execute());
         expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "def i = 1.1f; i", emptyMap()).execute());
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "def i = 1.1f; i", emptyMap())).execute());
         expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "def i = 1.1d; i", emptyMap()).execute());
+                ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "def i = 1.1d; i", emptyMap())).execute());
 
-        assertEquals(false, scriptEngine.compile(ReturnsPrimitiveBoolean.class, null, "int i = 0", emptyMap()).execute());
+        assertEquals(false, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "int i = 0", emptyMap())).execute());
     }
 
     public interface ReturnsPrimitiveInt {
@@ -225,16 +242,18 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         int execute();
     }
     public void testReturnsPrimitiveInt() {
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "1", emptyMap()).execute());
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "(int) 1L", emptyMap()).execute());
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "(int) 1.1d", emptyMap()).execute());
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "(int) 1.1f", emptyMap()).execute());
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "Integer.valueOf(1)", emptyMap()).execute());
+        Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, Definition.BUILTINS);
+
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1", emptyMap())).execute());
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "(int) 1L", emptyMap())).execute());
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "(int) 1.1d", emptyMap())).execute());
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "(int) 1.1f", emptyMap())).execute());
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "Integer.valueOf(1)", emptyMap())).execute());
 
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "def i = 1; i", emptyMap()).execute());
-        assertEquals(1, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "def i = Integer.valueOf(1); i", emptyMap()).execute());
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "def i = 1; i", emptyMap())).execute());
+        assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "def i = Integer.valueOf(1); i", emptyMap())).execute());
 
-        assertEquals(2, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "1 + 1", emptyMap()).execute());
+        assertEquals(2, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1 + 1", emptyMap())).execute());
 
         String debug = Debugger.toString(ReturnsPrimitiveInt.class, "1", new CompilerSettings());
         assertThat(debug, containsString("ICONST_1"));
@@ -242,22 +261,22 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         assertThat(debug, containsString("IRETURN"));
 
         Exception e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveInt.class, null, "1L", emptyMap()).execute());
+                ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1L", emptyMap())).execute());
         assertEquals("Cannot cast from [long] to [int].", e.getMessage());
         e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveInt.class, null, "1.1f", emptyMap()).execute());
+                ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute());
         assertEquals("Cannot cast from [float] to [int].", e.getMessage());
         e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveInt.class, null, "1.1d", emptyMap()).execute());
+                ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1.1d", emptyMap())).execute());
         assertEquals("Cannot cast from [double] to [int].", e.getMessage());
         expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveInt.class, null, "def i = 1L; i", emptyMap()).execute());
+                ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "def i = 1L; i", emptyMap())).execute());
         expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveInt.class, null, "def i = 1.1f; i", emptyMap()).execute());
+                ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "def i = 1.1f; i", emptyMap())).execute());
         expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveInt.class, null, "def i = 1.1d; i", emptyMap()).execute());
+                ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "def i = 1.1d; i", emptyMap())).execute());
 
-        assertEquals(0, scriptEngine.compile(ReturnsPrimitiveInt.class, null, "int i = 0", emptyMap()).execute());
+        assertEquals(0, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "int i = 0", emptyMap())).execute());
     }
 
     public interface ReturnsPrimitiveFloat {
@@ -265,28 +284,30 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         float execute();
     }
     public void testReturnsPrimitiveFloat() {
-        assertEquals(1.1f, scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "1.1f", emptyMap()).execute(), 0);
-        assertEquals(1.1f, scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "(float) 1.1d", emptyMap()).execute(), 0);
-        assertEquals(1.1f, scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "def d = 1.1f; d", emptyMap()).execute(), 0);
+        Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, Definition.BUILTINS);
+
+        assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute(), 0);
+        assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "(float) 1.1d", emptyMap())).execute(), 0);
+        assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "def d = 1.1f; d", emptyMap())).execute(), 0);
         assertEquals(1.1f,
-                scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "def d = Float.valueOf(1.1f); d", emptyMap()).execute(), 0);
+                ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "def d = Float.valueOf(1.1f); d", emptyMap())).execute(), 0);
 
-        assertEquals(1.1f + 6.7f, scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "1.1f + 6.7f", emptyMap()).execute(), 0);
+        assertEquals(1.1f + 6.7f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "1.1f + 6.7f", emptyMap())).execute(), 0);
 
         Exception e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "1.1d", emptyMap()).execute());
+                ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "1.1d", emptyMap())).execute());
         assertEquals("Cannot cast from [double] to [float].", e.getMessage());
         e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "def d = 1.1d; d", emptyMap()).execute());
+                ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "def d = 1.1d; d", emptyMap())).execute());
         e = expectScriptThrows(ClassCastException.class, () ->
-                scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "def d = Double.valueOf(1.1); d", emptyMap()).execute());
+                ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "def d = Double.valueOf(1.1); d", emptyMap())).execute());
 
         String debug = Debugger.toString(ReturnsPrimitiveFloat.class, "1f", new CompilerSettings());
         assertThat(debug, containsString("FCONST_1"));
         // The important thing here is that we have the bytecode for returning a float instead of an object
         assertThat(debug, containsString("FRETURN"));
 
-        assertEquals(0.0f, scriptEngine.compile(ReturnsPrimitiveFloat.class, null, "int i = 0", emptyMap()).execute(), 0);
+        assertEquals(0.0f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "int i = 0", emptyMap())).execute(), 0);
     }
 
     public interface ReturnsPrimitiveDouble {
@@ -294,39 +315,43 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         double execute();
     }
     public void testReturnsPrimitiveDouble() {
-        assertEquals(1.0, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "1", emptyMap()).execute(), 0);
-        assertEquals(1.0, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "1L", emptyMap()).execute(), 0);
-        assertEquals(1.1, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "1.1d", emptyMap()).execute(), 0);
-        assertEquals((double) 1.1f, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "1.1f", emptyMap()).execute(), 0);
-        assertEquals(1.1, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "Double.valueOf(1.1)", emptyMap()).execute(), 0);
+        Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, Definition.BUILTINS);
+
+        assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1", emptyMap())).execute(), 0);
+        assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1L", emptyMap())).execute(), 0);
+        assertEquals(1.1, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1.1d", emptyMap())).execute(), 0);
+        assertEquals((double) 1.1f, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute(), 0);
+        assertEquals(1.1, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "Double.valueOf(1.1)", emptyMap())).execute(), 0);
         assertEquals((double) 1.1f,
-                    scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "Float.valueOf(1.1f)", emptyMap()).execute(), 0);
+               ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "Float.valueOf(1.1f)", emptyMap())).execute(), 0);
 
-        assertEquals(1.0, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "def d = 1; d", emptyMap()).execute(), 0);
-        assertEquals(1.0, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "def d = 1L; d", emptyMap()).execute(), 0);
-        assertEquals(1.1, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "def d = 1.1d; d", emptyMap()).execute(), 0);
-        assertEquals((double) 1.1f, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "def d = 1.1f; d", emptyMap()).execute(), 0);
+        assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "def d = 1; d", emptyMap())).execute(), 0);
+        assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "def d = 1L; d", emptyMap())).execute(), 0);
+        assertEquals(1.1, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "def d = 1.1d; d", emptyMap())).execute(), 0);
+        assertEquals((double) 1.1f,
+                ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "def d = 1.1f; d", emptyMap())).execute(), 0);
         assertEquals(1.1,
-                scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "def d = Double.valueOf(1.1); d", emptyMap()).execute(), 0);
+                ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "def d = Double.valueOf(1.1); d", emptyMap())).execute(), 0);
         assertEquals((double) 1.1f,
-                scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "def d = Float.valueOf(1.1f); d", emptyMap()).execute(), 0);
+                ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "def d = Float.valueOf(1.1f); d", emptyMap())).execute(), 0);
 
-        assertEquals(1.1 + 6.7, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "1.1 + 6.7", emptyMap()).execute(), 0);
+        assertEquals(1.1 + 6.7, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1.1 + 6.7", emptyMap())).execute(), 0);
 
         String debug = Debugger.toString(ReturnsPrimitiveDouble.class, "1", new CompilerSettings());
         assertThat(debug, containsString("DCONST_1"));
         // The important thing here is that we have the bytecode for returning a double instead of an object
         assertThat(debug, containsString("DRETURN"));
 
-        assertEquals(0.0, scriptEngine.compile(ReturnsPrimitiveDouble.class, null, "int i = 0", emptyMap()).execute(), 0);
+        assertEquals(0.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "int i = 0", emptyMap())).execute(), 0);
     }
 
     public interface NoArgumentsConstant {
         Object execute(String foo);
     }
     public void testNoArgumentsConstant() {
+        Compiler compiler = new Compiler(NoArgumentsConstant.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(NoArgumentsConstant.class, null, "1", emptyMap()));
+            scriptEngine.compile(compiler, null, "1", emptyMap()));
         assertThat(e.getMessage(), startsWith("Painless needs a constant [String[] ARGUMENTS] on all interfaces it implements with the "
                 + "names of the method arguments but [" + NoArgumentsConstant.class.getName() + "] doesn't have one."));
     }
@@ -336,8 +361,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(String foo);
     }
     public void testWrongArgumentsConstant() {
+        Compiler compiler = new Compiler(WrongArgumentsConstant.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(WrongArgumentsConstant.class, null, "1", emptyMap()));
+            scriptEngine.compile(compiler, null, "1", emptyMap()));
         assertThat(e.getMessage(), startsWith("Painless needs a constant [String[] ARGUMENTS] on all interfaces it implements with the "
                 + "names of the method arguments but [" + WrongArgumentsConstant.class.getName() + "] doesn't have one."));
     }
@@ -347,8 +373,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(String foo);
     }
     public void testWrongLengthOfArgumentConstant() {
+        Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(WrongLengthOfArgumentConstant.class, null, "1", emptyMap()));
+            scriptEngine.compile(compiler, null, "1", emptyMap()));
         assertThat(e.getMessage(), startsWith("[" + WrongLengthOfArgumentConstant.class.getName() + "#ARGUMENTS] has length [2] but ["
                 + WrongLengthOfArgumentConstant.class.getName() + "#execute] takes [1] argument."));
     }
@@ -358,8 +385,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(UnknownArgType foo);
     }
     public void testUnknownArgType() {
+        Compiler compiler = new Compiler(UnknownArgType.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(UnknownArgType.class, null, "1", emptyMap()));
+            scriptEngine.compile(compiler, null, "1", emptyMap()));
         assertEquals("[foo] is of unknown type [" + UnknownArgType.class.getName() + ". Painless interfaces can only accept arguments "
                 + "that are of whitelisted types.", e.getMessage());
     }
@@ -369,8 +397,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         UnknownReturnType execute(String foo);
     }
     public void testUnknownReturnType() {
+        Compiler compiler = new Compiler(UnknownReturnType.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(UnknownReturnType.class, null, "1", emptyMap()));
+            scriptEngine.compile(compiler, null, "1", emptyMap()));
         assertEquals("Painless can only implement execute methods returning a whitelisted type but [" + UnknownReturnType.class.getName()
                 + "#execute] returns [" + UnknownReturnType.class.getName() + "] which isn't whitelisted.", e.getMessage());
     }
@@ -380,8 +409,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(UnknownArgTypeInArray[] foo);
     }
     public void testUnknownArgTypeInArray() {
+        Compiler compiler = new Compiler(UnknownArgTypeInArray.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(UnknownArgTypeInArray.class, null, "1", emptyMap()));
+            scriptEngine.compile(compiler, null, "1", emptyMap()));
         assertEquals("[foo] is of unknown type [" + UnknownArgTypeInArray.class.getName() + ". Painless interfaces can only accept "
                 + "arguments that are of whitelisted types.", e.getMessage());
     }
@@ -391,8 +421,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object execute(boolean foo);
     }
     public void testTwoExecuteMethods() {
+        Compiler compiler = new Compiler(TwoExecuteMethods.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(TwoExecuteMethods.class, null, "null", emptyMap()));
+            scriptEngine.compile(compiler, null, "null", emptyMap()));
         assertEquals("Painless can only implement interfaces that have a single method named [execute] but ["
                 + TwoExecuteMethods.class.getName() + "] has more than one.", e.getMessage());
     }
@@ -401,8 +432,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object something();
     }
     public void testBadMethod() {
+        Compiler compiler = new Compiler(BadMethod.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(BadMethod.class, null, "null", emptyMap()));
+            scriptEngine.compile(compiler, null, "null", emptyMap()));
         assertEquals("Painless can only implement methods named [execute] and [uses$argName] but [" + BadMethod.class.getName()
                 + "] contains a method named [something]", e.getMessage());
     }
@@ -413,8 +445,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         Object uses$foo();
     }
     public void testBadUsesReturn() {
+        Compiler compiler = new Compiler(BadUsesReturn.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(BadUsesReturn.class, null, "null", emptyMap()));
+            scriptEngine.compile(compiler, null, "null", emptyMap()));
         assertEquals("Painless can only implement uses$ methods that return boolean but [" + BadUsesReturn.class.getName()
                 + "#uses$foo] returns [java.lang.Object].", e.getMessage());
     }
@@ -425,8 +458,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         boolean uses$bar(boolean foo);
     }
     public void testBadUsesParameter() {
+        Compiler compiler = new Compiler(BadUsesParameter.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(BadUsesParameter.class, null, "null", emptyMap()));
+            scriptEngine.compile(compiler, null, "null", emptyMap()));
         assertEquals("Painless can only implement uses$ methods that do not take parameters but [" + BadUsesParameter.class.getName()
                 + "#uses$bar] does.", e.getMessage());
     }
@@ -437,8 +471,9 @@ public class ImplementInterfacesTests extends ScriptTestCase {
         boolean uses$baz();
     }
     public void testBadUsesName() {
+        Compiler compiler = new Compiler(BadUsesName.class, Definition.BUILTINS);
         Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
-            scriptEngine.compile(BadUsesName.class, null, "null", emptyMap()));
+            scriptEngine.compile(compiler, null, "null", emptyMap()));
         assertEquals("Painless can only implement uses$ methods that match a parameter name but [" + BadUsesName.class.getName()
                 + "#uses$baz] doesn't match any of [foo, bar].", e.getMessage());
     }

+ 4 - 1
modules/lang-painless/src/test/java/org/elasticsearch/painless/NeedsScoreTests.java

@@ -21,10 +21,12 @@ package org.elasticsearch.painless;
 
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexService;
+import org.elasticsearch.script.ExecutableScript;
 import org.elasticsearch.script.SearchScript;
 import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.test.ESSingleNodeTestCase;
 
+import java.util.Arrays;
 import java.util.Collections;
 
 /**
@@ -36,7 +38,8 @@ public class NeedsScoreTests extends ESSingleNodeTestCase {
     public void testNeedsScores() {
         IndexService index = createIndex("test", Settings.EMPTY, "type", "d", "type=double");
 
-        PainlessScriptEngine service = new PainlessScriptEngine(Settings.EMPTY);
+        PainlessScriptEngine service = new PainlessScriptEngine(Settings.EMPTY,
+            Arrays.asList(SearchScript.CONTEXT, ExecutableScript.CONTEXT));
         SearchLookup lookup = new SearchLookup(index.mapperService(), index.fieldData(), null);
 
         SearchScript.Factory factory = service.compile(null, "1.2", SearchScript.CONTEXT, Collections.emptyMap());

+ 18 - 1
modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java

@@ -26,11 +26,17 @@ import org.elasticsearch.common.lucene.ScorerAware;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.painless.antlr.Walker;
 import org.elasticsearch.script.ExecutableScript;
+import org.elasticsearch.script.ScriptContext;
 import org.elasticsearch.script.ScriptException;
+import org.elasticsearch.script.SearchScript;
 import org.elasticsearch.test.ESTestCase;
 import org.junit.Before;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.Matchers.hasSize;
@@ -45,7 +51,7 @@ public abstract class ScriptTestCase extends ESTestCase {
 
     @Before
     public void setup() {
-        scriptEngine = new PainlessScriptEngine(scriptEngineSettings());
+        scriptEngine = new PainlessScriptEngine(scriptEngineSettings(), scriptContexts());
     }
 
     /**
@@ -55,6 +61,17 @@ public abstract class ScriptTestCase extends ESTestCase {
         return Settings.EMPTY;
     }
 
+    /**
+     * Script contexts used to build the script engine. Override to customize which script contexts are available.
+     */
+    protected Collection<ScriptContext<?>> scriptContexts() {
+        Collection<ScriptContext<?>> contexts = new ArrayList<>();
+        contexts.add(SearchScript.CONTEXT);
+        contexts.add(ExecutableScript.CONTEXT);
+
+        return contexts;
+    }
+
     /** Compiles and returns the result of {@code script} */
     public Object exec(String script) {
         return exec(script, null, true);