Ver Fonte

Merge pull request #18899 from rmuir/more_def_cleanup

fix bugs in operators and more improvements for the dynamic case
Robert Muir há 9 anos atrás
pai
commit
154d750e4b
26 ficheiros alterados com 1188 adições e 440 exclusões
  1. 5 0
      buildSrc/src/main/resources/forbidden/es-all-signatures.txt
  2. 0 5
      buildSrc/src/main/resources/forbidden/es-core-signatures.txt
  3. 4 0
      modules/lang-painless/src/main/java/org/elasticsearch/painless/AnalyzerCaster.java
  4. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java
  5. 230 76
      modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java
  6. 84 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java
  7. 24 12
      modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java
  8. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EBinary.java
  9. 18 6
      modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EChain.java
  10. 8 6
      modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EComp.java
  11. 3 3
      modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EUnary.java
  12. 87 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/AdditionTests.java
  13. 76 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/AndTests.java
  14. 0 314
      modules/lang-painless/src/test/java/org/elasticsearch/painless/CompoundAssignmentTests.java
  15. 83 15
      modules/lang-painless/src/test/java/org/elasticsearch/painless/DefBootstrapTests.java
  16. 34 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/DefOptimizationTests.java
  17. 78 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/DivisionTests.java
  18. 31 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/IncrementTests.java
  19. 46 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/MultiplicationTests.java
  20. 76 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/OrTests.java
  21. 46 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/RemainderTests.java
  22. 9 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java
  23. 104 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/ShiftTests.java
  24. 16 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/StringTests.java
  25. 48 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/SubtractionTests.java
  26. 76 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/XorTests.java

+ 5 - 0
buildSrc/src/main/resources/forbidden/es-all-signatures.txt

@@ -31,3 +31,8 @@ org.apache.lucene.index.IndexReader#getCombinedCoreAndDeletesKey()
 
 @defaultMessage Soon to be removed
 org.apache.lucene.document.FieldType#numericType()
+
+@defaultMessage Don't use MethodHandles in slow ways, dont be lenient in tests.
+# unfortunately, invoke() cannot be banned, because forbidden apis does not support signature polymorphic methods
+java.lang.invoke.MethodHandle#invokeWithArguments(java.lang.Object[])
+java.lang.invoke.MethodHandle#invokeWithArguments(java.util.List)

+ 0 - 5
buildSrc/src/main/resources/forbidden/es-core-signatures.txt

@@ -92,8 +92,3 @@ org.joda.time.DateTime#<init>(int, int, int, int, int, int)
 org.joda.time.DateTime#<init>(int, int, int, int, int, int, int)
 org.joda.time.DateTime#now()
 org.joda.time.DateTimeZone#getDefault()
-
-@defaultMessage Don't use MethodHandles in slow ways, except in tests.
-java.lang.invoke.MethodHandle#invoke(java.lang.Object[])
-java.lang.invoke.MethodHandle#invokeWithArguments(java.lang.Object[])
-java.lang.invoke.MethodHandle#invokeWithArguments(java.util.List)

+ 4 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/AnalyzerCaster.java

@@ -826,6 +826,10 @@ public final class AnalyzerCaster {
         final Sort sort0 = from0.sort;
         final Sort sort1 = from1.sort;
 
+        if (sort0 == Sort.DEF || sort1 == Sort.DEF) {
+            return Definition.DEF_TYPE;
+        }
+
         if (sort0.bool || sort1.bool) {
             return Definition.BOOLEAN_TYPE;
         }

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/Def.java

@@ -138,7 +138,7 @@ public final class Def {
 
     /** Hack to rethrow unknown Exceptions from {@link MethodHandle#invokeExact}: */
     @SuppressWarnings("unchecked")
-    private static <T extends Throwable> void rethrow(Throwable t) throws T {
+    static <T extends Throwable> void rethrow(Throwable t) throws T {
         throw (T) t;
     }
     

+ 230 - 76
modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java

@@ -68,6 +68,21 @@ public final class DefBootstrap {
     public static final int BINARY_OPERATOR = 8;
     /** static bootstrap parameter indicating a shift operator, e.g. foo &gt;&gt; bar */
     public static final int SHIFT_OPERATOR = 9;
+    
+    // constants for the flags parameter of operators
+    /** 
+     * static bootstrap parameter indicating the binary operator allows nulls (e.g. == and +) 
+     * <p>
+     * requires additional {@link MethodHandles#catchException} guard, which will invoke
+     * the fallback if a null is encountered.
+     */
+    public static final int OPERATOR_ALLOWS_NULL = 1 << 0;
+    
+    /**
+     * static bootstrap parameter indicating the binary operator is part of compound assignment (e.g. +=).
+     * 
+     */
+    public static final int OPERATOR_COMPOUND_ASSIGNMENT = 1 << 1;
 
     /**
      * CallSite that implements the polymorphic inlining cache (PIC).
@@ -84,18 +99,15 @@ public final class DefBootstrap {
 
         PIC(Lookup lookup, String name, MethodType type, int flavor, Object[] args) {
             super(type);
+            if (type.parameterType(0) != Object.class) {
+                throw new BootstrapMethodError("The receiver type (1st arg) of invokedynamic descriptor must be Object.");
+            }
             this.lookup = lookup;
             this.name = name;
             this.flavor = flavor;
             this.args = args;
-            
-            // For operators use a monomorphic cache, fallback is fast.
-            // Just start with a depth of MAX-1, to keep it a constant.
-            if (flavor == UNARY_OPERATOR || flavor == BINARY_OPERATOR || flavor == SHIFT_OPERATOR) {
-                depth = MAX_DEPTH - 1;
-            }
 
-            final MethodHandle fallback = FALLBACK.bindTo(this)
+            MethodHandle fallback = FALLBACK.bindTo(this)
               .asCollector(Object[].class, type.parameterCount())
               .asType(type);
 
@@ -109,89 +121,178 @@ public final class DefBootstrap {
         static boolean checkClass(Class<?> clazz, Object receiver) {
             return receiver.getClass() == clazz;
         }
-        
-        /**
-         * guard method for inline caching: checks the receiver's class and the first argument
-         * are the same as the cached receiver and first argument.
-         */
-        static boolean checkBinary(Class<?> left, Class<?> right, Object leftObject, Object rightObject) {
-            return leftObject.getClass() == left && rightObject.getClass() == right;
-        }
-        
-        /**
-         * guard method for inline caching: checks the first argument is the same
-         * as the cached first argument.
-         */
-        static boolean checkBinaryArg(Class<?> left, Class<?> right, Object leftObject, Object rightObject) {
-            return rightObject.getClass() == right;
-        }
 
         /**
          * Does a slow lookup against the whitelist.
          */
-        private MethodHandle lookup(int flavor, String name, Object[] args) throws Throwable {
+        private MethodHandle lookup(int flavor, String name, Class<?> receiver, Object[] callArgs) throws Throwable {
             switch(flavor) {
                 case METHOD_CALL:
-                    return Def.lookupMethod(lookup, type(), args[0].getClass(), name, args, (Long) this.args[0]);
+                    return Def.lookupMethod(lookup, type(), receiver, name, callArgs, (Long) this.args[0]);
                 case LOAD:
-                    return Def.lookupGetter(args[0].getClass(), name);
+                    return Def.lookupGetter(receiver, name);
                 case STORE:
-                    return Def.lookupSetter(args[0].getClass(), name);
+                    return Def.lookupSetter(receiver, name);
                 case ARRAY_LOAD:
-                    return Def.lookupArrayLoad(args[0].getClass());
+                    return Def.lookupArrayLoad(receiver);
                 case ARRAY_STORE:
-                    return Def.lookupArrayStore(args[0].getClass());
+                    return Def.lookupArrayStore(receiver);
                 case ITERATOR:
-                    return Def.lookupIterator(args[0].getClass());
+                    return Def.lookupIterator(receiver);
                 case REFERENCE:
-                    return Def.lookupReference(lookup, (String) this.args[0], args[0].getClass(), name);
+                    return Def.lookupReference(lookup, (String) this.args[0], receiver, name);
+                default: throw new AssertionError();
+            }
+        }
+        
+        /**
+         * Creates the {@link MethodHandle} for the megamorphic call site
+         * using {@link ClassValue} and {@link MethodHandles#exactInvoker(MethodType)}:
+         * <p>
+         * TODO: Remove the variable args and just use {@code type()}!
+         */
+        private MethodHandle createMegamorphicHandle(final Object[] callArgs) throws Throwable {
+            final MethodType type = type();
+            final ClassValue<MethodHandle> megamorphicCache = new ClassValue<MethodHandle>() {
+                @Override
+                protected MethodHandle computeValue(Class<?> receiverType) {
+                    // it's too stupid that we cannot throw checked exceptions... (use rethrow puzzler):
+                    try {
+                        return lookup(flavor, name, receiverType, callArgs).asType(type);
+                    } catch (Throwable t) {
+                        Def.rethrow(t);
+                        throw new AssertionError();
+                    }
+                }
+            };
+            MethodHandle cacheLookup = MEGAMORPHIC_LOOKUP.bindTo(megamorphicCache);
+            cacheLookup = MethodHandles.dropArguments(cacheLookup,
+                    1, type.parameterList().subList(1, type.parameterCount()));
+            return MethodHandles.foldArguments(MethodHandles.exactInvoker(type), cacheLookup);            
+        }
+
+        /**
+         * Called when a new type is encountered (or, when we have encountered more than {@code MAX_DEPTH}
+         * types at this call site and given up on caching using this fallback and we switch to a
+         * megamorphic cache using {@link ClassValue}).
+         */
+        @SuppressForbidden(reason = "slow path")
+        Object fallback(final Object[] callArgs) throws Throwable {
+            if (depth >= MAX_DEPTH) {
+                // we revert the whole cache and build a new megamorphic one
+                final MethodHandle target = this.createMegamorphicHandle(callArgs);
+                
+                setTarget(target);
+                return target.invokeWithArguments(callArgs);                    
+            } else {
+                final Class<?> receiver = callArgs[0].getClass();
+                final MethodHandle target = lookup(flavor, name, receiver, callArgs).asType(type());
+    
+                MethodHandle test = CHECK_CLASS.bindTo(receiver);
+                MethodHandle guard = MethodHandles.guardWithTest(test, target, getTarget());
+                
+                depth++;
+    
+                setTarget(guard);
+                return target.invokeWithArguments(callArgs);
+            }
+        }
+
+        private static final MethodHandle CHECK_CLASS;
+        private static final MethodHandle FALLBACK;
+        private static final MethodHandle MEGAMORPHIC_LOOKUP;
+        static {
+            final Lookup lookup = MethodHandles.lookup();
+            final Lookup publicLookup = MethodHandles.publicLookup();
+            try {
+                CHECK_CLASS = lookup.findStatic(lookup.lookupClass(), "checkClass",
+                                              MethodType.methodType(boolean.class, Class.class, Object.class));
+                FALLBACK = lookup.findVirtual(lookup.lookupClass(), "fallback",
+                        MethodType.methodType(Object.class, Object[].class));
+                MethodHandle mh = publicLookup.findVirtual(ClassValue.class, "get",
+                        MethodType.methodType(Object.class, Class.class));
+                mh = MethodHandles.filterArguments(mh, 1, 
+                        publicLookup.findVirtual(Object.class, "getClass", MethodType.methodType(Class.class)));
+                MEGAMORPHIC_LOOKUP = mh.asType(mh.type().changeReturnType(MethodHandle.class));
+            } catch (ReflectiveOperationException e) {
+                throw new AssertionError(e);
+            }
+        }
+    }
+    
+    /**
+     * CallSite that implements the monomorphic inlining cache (for operators).
+     */
+    static final class MIC extends MutableCallSite {
+        private boolean initialized;
+        
+        private final String name;
+        private final int flavor;
+        private final int flags;
+
+        MIC(String name, MethodType type, int flavor, int flags) {
+            super(type);
+            this.name = name;
+            this.flavor = flavor;
+            this.flags = flags;
+            
+            MethodHandle fallback = FALLBACK.bindTo(this)
+              .asCollector(Object[].class, type.parameterCount())
+              .asType(type);
+
+            setTarget(fallback);
+        }
+        
+        /**
+         * Does a slow lookup for the operator
+         */
+        private MethodHandle lookup(Object[] args) throws Throwable {
+            switch(flavor) {
                 case UNARY_OPERATOR:
                 case SHIFT_OPERATOR:
                     // shifts are treated as unary, as java allows long arguments without a cast (but bits are ignored)
-                    return DefMath.lookupUnary(args[0].getClass(), name);
+                    MethodHandle unary = DefMath.lookupUnary(args[0].getClass(), name);
+                    if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0) {
+                        unary = DefMath.cast(args[0].getClass(), unary);
+                    }
+                    return unary;
                 case BINARY_OPERATOR:
                     if (args[0] == null || args[1] == null) {
-                        return getGeneric(flavor, name); // can handle nulls
+                        return lookupGeneric(); // can handle nulls, casts if supported
                     } else {
-                        return DefMath.lookupBinary(args[0].getClass(), args[1].getClass(), name);
+                        MethodHandle binary = DefMath.lookupBinary(args[0].getClass(), args[1].getClass(), name);
+                        if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0) {
+                            binary = DefMath.cast(args[0].getClass(), binary);
+                        }
+                        return binary;
                     }
                 default: throw new AssertionError();
             }
         }
         
-        /**
-         * Installs a permanent, generic solution that works with any parameter types, if possible.
-         */
-        private MethodHandle getGeneric(int flavor, String name) throws Throwable {
-            switch(flavor) {
-                case UNARY_OPERATOR:
-                case BINARY_OPERATOR:
-                case SHIFT_OPERATOR:
-                    return DefMath.lookupGeneric(name);
-                default:
-                    return null;
+        private MethodHandle lookupGeneric() throws Throwable {
+            if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0) {
+                return DefMath.lookupGenericWithCast(name);
+            } else {
+                return DefMath.lookupGeneric(name);
             }
         }
-
+        
         /**
          * Called when a new type is encountered (or, when we have encountered more than {@code MAX_DEPTH}
          * types at this call site and given up on caching).
          */
         @SuppressForbidden(reason = "slow path")
         Object fallback(Object[] args) throws Throwable {
-            if (depth >= MAX_DEPTH) {
+            if (initialized) {
                 // caching defeated
-                MethodHandle generic = getGeneric(flavor, name);
-                if (generic != null) {
-                    setTarget(generic.asType(type()));
-                    return generic.invokeWithArguments(args);
-                } else {
-                    return lookup(flavor, name, args).invokeWithArguments(args);
-                }
+                MethodHandle generic = lookupGeneric();
+                setTarget(generic.asType(type()));
+                return generic.invokeWithArguments(args);
             }
             
             final MethodType type = type();
-            final MethodHandle target = lookup(flavor, name, args).asType(type);
+            final MethodHandle target = lookup(args).asType(type);
 
             final MethodHandle test;
             if (flavor == BINARY_OPERATOR || flavor == SHIFT_OPERATOR) {
@@ -200,24 +301,25 @@ public final class DefBootstrap {
                 Class<?> clazz1 = args[1] == null ? null : args[1].getClass();
                 if (type.parameterType(1) != Object.class) {
                     // case 1: only the receiver is unknown, just check that
-                    MethodHandle unaryTest = CHECK_CLASS.bindTo(clazz0);
+                    MethodHandle unaryTest = CHECK_LHS.bindTo(clazz0);
                     test = unaryTest.asType(unaryTest.type()
                                             .changeParameterType(0, type.parameterType(0)));
                 } else if (type.parameterType(0) != Object.class) {
                     // case 2: only the argument is unknown, just check that
-                    MethodHandle unaryTest = CHECK_BINARY_ARG.bindTo(clazz0).bindTo(clazz1);
+                    MethodHandle unaryTest = CHECK_RHS.bindTo(clazz0).bindTo(clazz1);
                     test = unaryTest.asType(unaryTest.type()
                                             .changeParameterType(0, type.parameterType(0))
                                             .changeParameterType(1, type.parameterType(1)));
                 } else {
                     // case 3: check both receiver and argument
-                    MethodHandle binaryTest = CHECK_BINARY.bindTo(clazz0).bindTo(clazz1);
+                    MethodHandle binaryTest = CHECK_BOTH.bindTo(clazz0).bindTo(clazz1);
                     test = binaryTest.asType(binaryTest.type()
                                             .changeParameterType(0, type.parameterType(0))
                                             .changeParameterType(1, type.parameterType(1)));
                 }
             } else {
-                MethodHandle receiverTest = CHECK_CLASS.bindTo(args[0].getClass());
+                // unary operator
+                MethodHandle receiverTest = CHECK_LHS.bindTo(args[0].getClass());
                 test = receiverTest.asType(receiverTest.type()
                                         .changeParameterType(0, type.parameterType(0)));
             }
@@ -225,29 +327,55 @@ public final class DefBootstrap {
             MethodHandle guard = MethodHandles.guardWithTest(test, target, getTarget());
             // very special cases, where even the receiver can be null (see JLS rules for string concat)
             // we wrap + with an NPE catcher, and use our generic method in that case.
-            if (flavor == BINARY_OPERATOR && "add".equals(name) || "eq".equals(name)) {
-                MethodHandle handler = MethodHandles.dropArguments(getGeneric(flavor, name).asType(type()), 0, NullPointerException.class);
+            if (flavor == BINARY_OPERATOR && (flags & OPERATOR_ALLOWS_NULL) != 0) {
+                MethodHandle handler = MethodHandles.dropArguments(lookupGeneric().asType(type()), 
+                                                                   0, 
+                                                                   NullPointerException.class);
                 guard = MethodHandles.catchException(guard, NullPointerException.class, handler);
             }
             
-            depth++;
+            initialized = true;
 
             setTarget(guard);
             return target.invokeWithArguments(args);
         }
+        
+        /**
+         * guard method for inline caching: checks the receiver's class is the same
+         * as the cached class
+         */
+        static boolean checkLHS(Class<?> clazz, Object leftObject) {
+            return leftObject.getClass() == clazz;
+        }
 
-        private static final MethodHandle CHECK_CLASS;
-        private static final MethodHandle CHECK_BINARY;
-        private static final MethodHandle CHECK_BINARY_ARG;
+        /**
+         * guard method for inline caching: checks the first argument is the same
+         * as the cached first argument.
+         */
+        static boolean checkRHS(Class<?> left, Class<?> right, Object leftObject, Object rightObject) {
+            return rightObject.getClass() == right;
+        }
+        
+        /**
+         * guard method for inline caching: checks the receiver's class and the first argument
+         * are the same as the cached receiver and first argument.
+         */
+        static boolean checkBoth(Class<?> left, Class<?> right, Object leftObject, Object rightObject) {
+            return leftObject.getClass() == left && rightObject.getClass() == right;
+        }
+        
+        private static final MethodHandle CHECK_LHS;
+        private static final MethodHandle CHECK_RHS;
+        private static final MethodHandle CHECK_BOTH;
         private static final MethodHandle FALLBACK;
         static {
             final Lookup lookup = MethodHandles.lookup();
             try {
-                CHECK_CLASS = lookup.findStatic(lookup.lookupClass(), "checkClass",
+                CHECK_LHS = lookup.findStatic(lookup.lookupClass(), "checkLHS",
                                               MethodType.methodType(boolean.class, Class.class, Object.class));
-                CHECK_BINARY = lookup.findStatic(lookup.lookupClass(), "checkBinary",
+                CHECK_RHS = lookup.findStatic(lookup.lookupClass(), "checkRHS",
                                               MethodType.methodType(boolean.class, Class.class, Class.class, Object.class, Object.class));
-                CHECK_BINARY_ARG = lookup.findStatic(lookup.lookupClass(), "checkBinaryArg",
+                CHECK_BOTH = lookup.findStatic(lookup.lookupClass(), "checkBoth",
                                               MethodType.methodType(boolean.class, Class.class, Class.class, Object.class, Object.class));
                 FALLBACK = lookup.findVirtual(lookup.lookupClass(), "fallback",
                                               MethodType.methodType(Object.class, Object[].class));
@@ -268,6 +396,7 @@ public final class DefBootstrap {
     public static CallSite bootstrap(Lookup lookup, String name, MethodType type, int flavor, Object... args) {
         // validate arguments
         switch(flavor) {
+            // "function-call" like things get a polymorphic cache
             case METHOD_CALL:
                 if (args.length != 1) {
                     throw new BootstrapMethodError("Invalid number of parameters for method call");
@@ -279,7 +408,16 @@ public final class DefBootstrap {
                 if (Long.bitCount(recipe) > type.parameterCount()) {
                     throw new BootstrapMethodError("Illegal recipe for method call: too many bits");
                 }
-                break;
+                return new PIC(lookup, name, type, flavor, args);
+            case LOAD:
+            case STORE:
+            case ARRAY_LOAD:
+            case ARRAY_STORE:
+            case ITERATOR:
+                if (args.length > 0) {
+                    throw new BootstrapMethodError("Illegal static bootstrap parameters for flavor: " + flavor);
+                }
+                return new PIC(lookup, name, type, flavor, args);
             case REFERENCE:
                 if (args.length != 1) {
                     throw new BootstrapMethodError("Invalid number of parameters for reference call");
@@ -287,14 +425,30 @@ public final class DefBootstrap {
                 if (args[0] instanceof String == false) {
                     throw new BootstrapMethodError("Illegal parameter for reference call: " + args[0]);
                 }
-                break;
-            default:
-                if (args.length > 0) {
-                    throw new BootstrapMethodError("Illegal static bootstrap parameters for flavor: " + flavor);
+                return new PIC(lookup, name, type, flavor, args);
+
+            // operators get monomorphic cache, with a generic impl for a fallback
+            case UNARY_OPERATOR:
+            case SHIFT_OPERATOR:
+            case BINARY_OPERATOR:
+                if (args.length != 1) {
+                    throw new BootstrapMethodError("Invalid number of parameters for operator call");
+                }
+                if (args[0] instanceof Integer == false) {
+                    throw new BootstrapMethodError("Illegal parameter for reference call: " + args[0]);
+                }
+                int flags = (int)args[0];
+                if ((flags & OPERATOR_ALLOWS_NULL) != 0 && flavor != BINARY_OPERATOR) {
+                    // we just don't need it anywhere else.
+                    throw new BootstrapMethodError("This parameter is only supported for BINARY_OPERATORs");
                 }
-                break;
+                if ((flags & OPERATOR_COMPOUND_ASSIGNMENT) != 0 && flavor != BINARY_OPERATOR) {
+                    // we just don't need it anywhere else.
+                    throw new BootstrapMethodError("This parameter is only supported for BINARY_OPERATORs");
+                }
+                return new MIC(name, type, flavor, flags);
+            default:
+                throw new BootstrapMethodError("Illegal static bootstrap parameter for flavor: " + flavor);
         }
-        return new PIC(lookup, name, type, flavor, args);
     }
-
 }

+ 84 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/DefMath.java

@@ -1133,7 +1133,7 @@ public class DefMath {
         return handle;
     }
     
-    /** Returns an appropriate method handle for a binary operator, based only promotion of the LHS and RHS arguments */
+    /** Returns an appropriate method handle for a binary operator, based on promotion of the LHS and RHS arguments */
     public static MethodHandle lookupBinary(Class<?> classA, Class<?> classB, String name) {
         MethodHandle handle = TYPE_OP_MAPPING.get(promote(promote(unbox(classA)), promote(unbox(classB)))).get(name);
         if (handle == null) {
@@ -1146,4 +1146,87 @@ public class DefMath {
     public static MethodHandle lookupGeneric(String name) {
         return TYPE_OP_MAPPING.get(Object.class).get(name);
     }
+    
+    
+    /**
+     * Slow dynamic cast: casts {@code returnValue} to the runtime type of {@code lhs}
+     * based upon inspection. If {@code lhs} is null, no cast takes place.
+     * This is used for the generic fallback case of compound assignment.
+     */
+    static Object dynamicCast(Object returnValue, Object lhs) {
+        if (lhs != null) {
+            Class<?> c = lhs.getClass();
+            if (c == returnValue.getClass()) {
+                return returnValue;
+            }
+            if (c == Integer.class) {
+                return getNumber(returnValue).intValue();
+            } else if (c == Long.class) {
+                return getNumber(returnValue).longValue();
+            } else if (c == Double.class) {
+                return getNumber(returnValue).doubleValue();
+            } else if (c == Float.class) {
+                return getNumber(returnValue).floatValue();
+            } else if (c == Short.class) {
+                return getNumber(returnValue).shortValue();
+            } else if (c == Byte.class) {
+                return getNumber(returnValue).byteValue();
+            } else if (c == Character.class) {
+                return (char) getNumber(returnValue).intValue();
+            }
+            return lhs.getClass().cast(returnValue);
+        } else {
+            return returnValue;
+        }
+    }
+    
+    /** Slowly returns a Number for o. Just for supporting dynamicCast */
+    static Number getNumber(Object o) {
+        if (o instanceof Number) {
+            return (Number)o;
+        } else if (o instanceof Character) {
+            return Integer.valueOf((char)o);
+        } else {
+            throw new ClassCastException("Cannot convert [" + o.getClass() + "] to a Number");
+        }
+    }
+    
+    private static final MethodHandle DYNAMIC_CAST;
+    static {
+        final Lookup lookup = MethodHandles.lookup();
+        try {
+            DYNAMIC_CAST = lookup.findStatic(lookup.lookupClass(), 
+                                             "dynamicCast", 
+                                             MethodType.methodType(Object.class, Object.class, Object.class));
+        } catch (ReflectiveOperationException e) {
+            throw new AssertionError(e);
+        }
+    }
+
+    /** Looks up generic method, with a dynamic cast to the receiver's type. (compound assignment) */
+    public static MethodHandle lookupGenericWithCast(String name) {
+        MethodHandle generic = lookupGeneric(name);
+        // adapt dynamic cast to the generic method
+        MethodHandle cast = DYNAMIC_CAST.asType(MethodType.methodType(generic.type().returnType(), 
+                                                                      generic.type().returnType(),
+                                                                      generic.type().parameterType(0)));
+        // drop the RHS parameter
+        cast = MethodHandles.dropArguments(cast, 2, generic.type().parameterType(1));
+        // combine: f(x,y) -> g(f(x,y), x, y);
+        return MethodHandles.foldArguments(cast, generic);
+    }
+    
+    /** Forces a cast to class A for target (only if types differ) */
+    public static MethodHandle cast(Class<?> classA, MethodHandle target) {
+        MethodType newType = MethodType.methodType(classA).unwrap();
+        MethodType targetType = MethodType.methodType(target.type().returnType()).unwrap();
+        
+        if (newType.returnType() == targetType.returnType()) {
+            return target; // no conversion
+        }
+        
+        // this is safe for our uses of it here only, because we change just the return value,
+        // the original method itself does all the type checks correctly.
+        return MethodHandles.explicitCastArguments(target, target.type().changeReturnType(newType.returnType()));
+    }
 }

+ 24 - 12
modules/lang-painless/src/main/java/org/elasticsearch/painless/MethodWriter.java

@@ -263,43 +263,55 @@ public final class MethodWriter extends GeneratorAdapter {
     }
 
     /** Writes a dynamic binary instruction: returnType, lhs, and rhs can be different */
-    public void writeDynamicBinaryInstruction(Location location, Type returnType, Type lhs, Type rhs, Operation operation) {
+    public void writeDynamicBinaryInstruction(Location location, Type returnType, Type lhs, Type rhs, 
+                                              Operation operation, boolean compoundAssignment) {
         org.objectweb.asm.Type methodType = org.objectweb.asm.Type.getMethodType(returnType.type, lhs.type, rhs.type);
         String descriptor = methodType.getDescriptor();
         
+        int flags = 0;
+        if (compoundAssignment) {
+            flags |= DefBootstrap.OPERATOR_COMPOUND_ASSIGNMENT;
+        }
         switch (operation) {
             case MUL:
-                invokeDynamic("mul", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); 
+                invokeDynamic("mul", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags); 
                 break;
             case DIV:
-                invokeDynamic("div", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); 
+                invokeDynamic("div", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags); 
                 break;
             case REM:
-                invokeDynamic("rem", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); 
+                invokeDynamic("rem", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags); 
                 break;
             case ADD:
-                invokeDynamic("add", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); 
+                // if either side is primitive, then the + operator should always throw NPE on null,
+                // so we don't need a special NPE guard.
+                // otherwise, we need to allow nulls for possible string concatenation.
+                boolean hasPrimitiveArg = lhs.clazz.isPrimitive() || rhs.clazz.isPrimitive();
+                if (!hasPrimitiveArg) {
+                    flags |= DefBootstrap.OPERATOR_ALLOWS_NULL;
+                }
+                invokeDynamic("add", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags);
                 break;
             case SUB:
-                invokeDynamic("sub", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR); 
+                invokeDynamic("sub", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags); 
                 break;
             case LSH:
-                invokeDynamic("lsh", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR); 
+                invokeDynamic("lsh", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR, flags);
                 break;
             case USH:
-                invokeDynamic("ush", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR); 
+                invokeDynamic("ush", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR, flags); 
                 break;
             case RSH:
-                invokeDynamic("rsh", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR); 
+                invokeDynamic("rsh", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.SHIFT_OPERATOR, flags); 
                 break;
             case BWAND: 
-                invokeDynamic("and", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                invokeDynamic("and", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags);
                 break;
             case XOR:   
-                invokeDynamic("xor", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                invokeDynamic("xor", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags);
                 break;
             case BWOR:  
-                invokeDynamic("or", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                invokeDynamic("or", descriptor, DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, flags);
                 break;
             default:
                 throw location.createError(new IllegalStateException("Illegal tree structure."));

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EBinary.java

@@ -612,7 +612,7 @@ public final class EBinary extends AExpression {
             right.write(writer);
 
             if (promote.sort == Sort.DEF || (shiftDistance != null && shiftDistance.sort == Sort.DEF)) {
-                writer.writeDynamicBinaryInstruction(location, actual, left.actual, right.actual, operation);
+                writer.writeDynamicBinaryInstruction(location, actual, left.actual, right.actual, operation, false);
             } else {
                 writer.writeBinaryInstruction(location, actual, operation);
             }

+ 18 - 6
modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EChain.java

@@ -45,6 +45,7 @@ public final class EChain extends AExpression {
 
     boolean cat = false;
     Type promote = null;
+    Type shiftDistance; // for shifts, the RHS is promoted independently
     Cast there = null;
     Cast back = null;
     
@@ -163,6 +164,7 @@ public final class EChain extends AExpression {
         ALink last = links.get(links.size() - 1);
 
         expression.analyze(variables);
+        boolean shift = false;
 
         if (operation == Operation.MUL) {
             promote = AnalyzerCaster.promoteNumeric(last.after, expression.actual, true);
@@ -176,10 +178,16 @@ public final class EChain extends AExpression {
             promote = AnalyzerCaster.promoteNumeric(last.after, expression.actual, true);
         } else if (operation == Operation.LSH) {
             promote = AnalyzerCaster.promoteNumeric(last.after, false);
+            shiftDistance = AnalyzerCaster.promoteNumeric(expression.actual, false);
+            shift = true;
         } else if (operation == Operation.RSH) {
             promote = AnalyzerCaster.promoteNumeric(last.after, false);
+            shiftDistance = AnalyzerCaster.promoteNumeric(expression.actual, false);
+            shift = true;
         } else if (operation == Operation.USH) {
             promote = AnalyzerCaster.promoteNumeric(last.after, false);
+            shiftDistance = AnalyzerCaster.promoteNumeric(expression.actual, false);
+            shift = true;
         } else if (operation == Operation.BWAND) {
             promote = AnalyzerCaster.promoteXor(last.after, expression.actual);
         } else if (operation == Operation.XOR) {
@@ -190,7 +198,7 @@ public final class EChain extends AExpression {
             throw createError(new IllegalStateException("Illegal tree structure."));
         }
 
-        if (promote == null) {
+        if (promote == null || (shift && shiftDistance == null)) {
             throw createError(new ClassCastException("Cannot apply compound assignment " +
                 "[" + operation.symbol + "=] to types [" + last.after + "] and [" + expression.actual + "]."));
         }
@@ -204,9 +212,13 @@ public final class EChain extends AExpression {
             }
 
             expression.expected = expression.actual;
-        } else if (operation == Operation.LSH || operation == Operation.RSH || operation == Operation.USH) {
-            expression.expected = Definition.INT_TYPE;
-            expression.explicit = true;
+        } else if (shift) {
+            if (shiftDistance.sort == Sort.LONG) {
+                expression.expected = Definition.INT_TYPE;
+                expression.explicit = true;   
+            } else {
+                expression.expected = shiftDistance;
+            }
         } else {
             expression.expected = promote;
         }
@@ -335,11 +347,11 @@ public final class EChain extends AExpression {
                                                                                  // to the promotion type between the lhs and rhs types
                     expression.write(writer);                                    // write the bytecode for the rhs expression
                     // XXX: fix these types, but first we need def compound assignment tests.
-                    // (and also corner cases such as shifts). its tricky here as there are possibly explicit casts, too.
+                    // its tricky here as there are possibly explicit casts, too.
                     // write the operation instruction for compound assignment
                     if (promote.sort == Sort.DEF) {
                         writer.writeDynamicBinaryInstruction(location, promote, 
-                            Definition.DEF_TYPE, Definition.DEF_TYPE, operation);
+                            Definition.DEF_TYPE, Definition.DEF_TYPE, operation, true);
                     } else {
                         writer.writeBinaryInstruction(location, promote, operation);
                     }

+ 8 - 6
modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EComp.java

@@ -497,7 +497,8 @@ public final class EComp extends AExpression {
                     if (right.isNull) {
                         writer.ifNull(jump);
                     } else if (!left.isNull && (operation == Operation.EQ || operation == Operation.NE)) {
-                        writer.invokeDynamic("eq", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                        writer.invokeDynamic("eq", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR,
+                                                                                                     DefBootstrap.OPERATOR_ALLOWS_NULL);
                         writejump = false;
                     } else {
                         writer.ifCmp(promotedType.type, MethodWriter.EQ, jump);
@@ -506,22 +507,23 @@ public final class EComp extends AExpression {
                     if (right.isNull) {
                         writer.ifNonNull(jump);
                     } else if (!left.isNull && (operation == Operation.EQ || operation == Operation.NE)) {
-                        writer.invokeDynamic("eq", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                        writer.invokeDynamic("eq", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR,
+                                                                                                     DefBootstrap.OPERATOR_ALLOWS_NULL);
                         writer.ifZCmp(MethodWriter.EQ, jump);
                     } else {
                         writer.ifCmp(promotedType.type, MethodWriter.NE, jump);
                     }
                 } else if (lt) {
-                    writer.invokeDynamic("lt", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                    writer.invokeDynamic("lt", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, 0);
                     writejump = false;
                 } else if (lte) {
-                    writer.invokeDynamic("lte", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                    writer.invokeDynamic("lte", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, 0);
                     writejump = false;
                 } else if (gt) {
-                    writer.invokeDynamic("gt", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                    writer.invokeDynamic("gt", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, 0);
                     writejump = false;
                 } else if (gte) {
-                    writer.invokeDynamic("gte", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR);
+                    writer.invokeDynamic("gte", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.BINARY_OPERATOR, 0);
                     writejump = false;
                 } else {
                     throw createError(new IllegalStateException("Illegal tree structure."));

+ 3 - 3
modules/lang-painless/src/main/java/org/elasticsearch/painless/node/EUnary.java

@@ -205,7 +205,7 @@ public final class EUnary extends AExpression {
             if (operation == Operation.BWNOT) {
                 if (sort == Sort.DEF) {
                     org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(actual.type, child.actual.type);
-                    writer.invokeDynamic("not", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR);
+                    writer.invokeDynamic("not", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR, 0);
                 } else {
                     if (sort == Sort.INT) {
                         writer.push(-1);
@@ -220,14 +220,14 @@ public final class EUnary extends AExpression {
             } else if (operation == Operation.SUB) {
                 if (sort == Sort.DEF) {
                     org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(actual.type, child.actual.type);
-                    writer.invokeDynamic("neg", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR);
+                    writer.invokeDynamic("neg", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR, 0);
                 } else {
                     writer.math(MethodWriter.NEG, actual.type);
                 }
             } else if (operation == Operation.ADD) {
                 if (sort == Sort.DEF) {
                     org.objectweb.asm.Type descriptor = org.objectweb.asm.Type.getMethodType(actual.type, child.actual.type);
-                    writer.invokeDynamic("plus", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR);
+                    writer.invokeDynamic("plus", descriptor.getDescriptor(), DEF_BOOTSTRAP_HANDLE, DefBootstrap.UNARY_OPERATOR, 0);
                 } 
             } else {
                 throw createError(new IllegalStateException("Illegal tree structure."));

+ 87 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/AdditionTests.java

@@ -369,4 +369,91 @@ public class AdditionTests extends ScriptTestCase {
         assertEquals(2D, exec("def x = (float)1; double y = (double)1; return x + y"));
         assertEquals(2D, exec("def x = (double)1; double y = (double)1; return x + y"));
     }
+    
+    public void testDefNulls() {
+        expectScriptThrows(NullPointerException.class, () -> {
+            exec("def x = null; int y = 1; return x + y"); 
+        });
+        expectScriptThrows(NullPointerException.class, () -> {
+            exec("int x = 1; def y = null; return x + y"); 
+        });
+        expectScriptThrows(NullPointerException.class, () -> {
+            exec("def x = null; def y = 1; return x + y"); 
+        });
+    }
+    
+    public void testCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 5; x += 10; return x;"));
+        assertEquals((byte) -5, exec("byte x = 5; x += -10; return x;"));
+
+        // short
+        assertEquals((short) 15, exec("short x = 5; x += 10; return x;"));
+        assertEquals((short) -5, exec("short x = 5; x += -10; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = 5; x += 10; return x;"));
+        assertEquals((char) 5, exec("char x = 10; x += -5; return x;"));
+        // int
+        assertEquals(15, exec("int x = 5; x += 10; return x;"));
+        assertEquals(-5, exec("int x = 5; x += -10; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 5; x += 10; return x;"));
+        assertEquals(-5L, exec("long x = 5; x += -10; return x;"));
+        // float
+        assertEquals(15F, exec("float x = 5f; x += 10; return x;"));
+        assertEquals(-5F, exec("float x = 5f; x += -10; return x;"));
+        // double
+        assertEquals(15D, exec("double x = 5.0; x += 10; return x;"));
+        assertEquals(-5D, exec("double x = 5.0; x += -10; return x;"));
+    }
+    
+    public void testDefCompoundAssignmentLHS() {
+        // byte
+        assertEquals((byte) 15, exec("def x = (byte)5; x += 10; return x;"));
+        assertEquals((byte) -5, exec("def x = (byte)5; x += -10; return x;"));
+
+        // short
+        assertEquals((short) 15, exec("def x = (short)5; x += 10; return x;"));
+        assertEquals((short) -5, exec("def x = (short)5; x += -10; return x;"));
+        // char
+        assertEquals((char) 15, exec("def x = (char)5; x += 10; return x;"));
+        assertEquals((char) 5, exec("def x = (char)10; x += -5; return x;"));
+        // int
+        assertEquals(15, exec("def x = 5; x += 10; return x;"));
+        assertEquals(-5, exec("def x = 5; x += -10; return x;"));
+        // long
+        assertEquals(15L, exec("def x = 5L; x += 10; return x;"));
+        assertEquals(-5L, exec("def x = 5L; x += -10; return x;"));
+        // float
+        assertEquals(15F, exec("def x = 5f; x += 10; return x;"));
+        assertEquals(-5F, exec("def x = 5f; x += -10; return x;"));
+        // double
+        assertEquals(15D, exec("def x = 5.0; x += 10; return x;"));
+        assertEquals(-5D, exec("def x = 5.0; x += -10; return x;"));
+    }
+    
+    public void testDefCompoundAssignmentRHS() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 5; def y = 10; x += y; return x;"));
+        assertEquals((byte) -5, exec("byte x = 5; def y = -10; x += y; return x;"));
+
+        // short
+        assertEquals((short) 15, exec("short x = 5; def y = 10; x += y; return x;"));
+        assertEquals((short) -5, exec("short x = 5; def y = -10; x += y; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = 5; def y = 10; x += y; return x;"));
+        assertEquals((char) 5, exec("char x = 10; def y = -5; x += y; return x;"));
+        // int
+        assertEquals(15, exec("int x = 5; def y = 10; x += y; return x;"));
+        assertEquals(-5, exec("int x = 5; def y = -10; x += y; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 5; def y = 10; x += y; return x;"));
+        assertEquals(-5L, exec("long x = 5; def y = -10; x += y; return x;"));
+        // float
+        assertEquals(15F, exec("float x = 5f; def y = 10; x += y; return x;"));
+        assertEquals(-5F, exec("float x = 5f; def y = -10; x += y; return x;"));
+        // double
+        assertEquals(15D, exec("double x = 5.0; def y = 10; x += y; return x;"));
+        assertEquals(-5D, exec("double x = 5.0; def y = -10; x += y; return x;"));
+    }
 }

+ 76 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/AndTests.java

@@ -208,4 +208,80 @@ public class AndTests extends ScriptTestCase {
         assertEquals(false, exec("def x = false; boolean y = true; return x & y"));
         assertEquals(false, exec("def x = false; boolean y = false; return x & y"));
     }
+    
+    public void testCompoundAssignment() {
+        // boolean
+        assertEquals(true, exec("boolean x = true; x &= true; return x;"));
+        assertEquals(false, exec("boolean x = true; x &= false; return x;"));
+        assertEquals(false, exec("boolean x = false; x &= true; return x;"));
+        assertEquals(false, exec("boolean x = false; x &= false; return x;"));
+        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] &= true; return x[0];"));
+        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] &= false; return x[0];"));
+        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] &= true; return x[0];"));
+        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] &= false; return x[0];"));
+
+        // byte
+        assertEquals((byte) (13 & 14), exec("byte x = 13; x &= 14; return x;"));
+        // short
+        assertEquals((short) (13 & 14), exec("short x = 13; x &= 14; return x;"));
+        // char
+        assertEquals((char) (13 & 14), exec("char x = 13; x &= 14; return x;"));
+        // int
+        assertEquals(13 & 14, exec("int x = 13; x &= 14; return x;"));
+        // long
+        assertEquals((long) (13 & 14), exec("long x = 13L; x &= 14; return x;"));
+    }
+    
+    public void testBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("float x = 4; int y = 1; x &= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("double x = 4; int y = 1; x &= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; float y = 1; x &= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; double y = 1; x &= y");
+        });
+    }
+    
+    public void testDefCompoundAssignment() {
+        // boolean
+        assertEquals(true, exec("def x = true; x &= true; return x;"));
+        assertEquals(false, exec("def x = true; x &= false; return x;"));
+        assertEquals(false, exec("def x = false; x &= true; return x;"));
+        assertEquals(false, exec("def x = false; x &= false; return x;"));
+        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] &= true; return x[0];"));
+        assertEquals(false, exec("def[] x = new def[1]; x[0] = true; x[0] &= false; return x[0];"));
+        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] &= true; return x[0];"));
+        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] &= false; return x[0];"));
+
+        // byte
+        assertEquals((byte) (13 & 14), exec("def x = (byte)13; x &= 14; return x;"));
+        // short
+        assertEquals((short) (13 & 14), exec("def x = (short)13; x &= 14; return x;"));
+        // char
+        assertEquals((char) (13 & 14), exec("def x = (char)13; x &= 14; return x;"));
+        // int
+        assertEquals(13 & 14, exec("def x = 13; x &= 14; return x;"));
+        // long
+        assertEquals((long) (13 & 14), exec("def x = 13L; x &= 14; return x;"));
+    }
+    
+    public void testDefBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4F; int y = 1; x &= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4D; int y = 1; x &= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; def y = 1F; x &= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; def y = 1D; x &= y");
+        });
+    }
 }

+ 0 - 314
modules/lang-painless/src/test/java/org/elasticsearch/painless/CompoundAssignmentTests.java

@@ -1,314 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.painless;
-
-/**
- * Tests compound assignments (+=, etc) across all data types
- */
-public class CompoundAssignmentTests extends ScriptTestCase {
-    public void testAddition() {
-        // byte
-        assertEquals((byte) 15, exec("byte x = 5; x += 10; return x;"));
-        assertEquals((byte) -5, exec("byte x = 5; x += -10; return x;"));
-
-        // short
-        assertEquals((short) 15, exec("short x = 5; x += 10; return x;"));
-        assertEquals((short) -5, exec("short x = 5; x += -10; return x;"));
-        // char
-        assertEquals((char) 15, exec("char x = 5; x += 10; return x;"));
-        assertEquals((char) 5, exec("char x = 10; x += -5; return x;"));
-        // int
-        assertEquals(15, exec("int x = 5; x += 10; return x;"));
-        assertEquals(-5, exec("int x = 5; x += -10; return x;"));
-        // long
-        assertEquals(15L, exec("long x = 5; x += 10; return x;"));
-        assertEquals(-5L, exec("long x = 5; x += -10; return x;"));
-        // float
-        assertEquals(15F, exec("float x = 5f; x += 10; return x;"));
-        assertEquals(-5F, exec("float x = 5f; x += -10; return x;"));
-        // double
-        assertEquals(15D, exec("double x = 5.0; x += 10; return x;"));
-        assertEquals(-5D, exec("double x = 5.0; x += -10; return x;"));
-    }
-
-    public void testSubtraction() {
-        // byte
-        assertEquals((byte) 15, exec("byte x = 5; x -= -10; return x;"));
-        assertEquals((byte) -5, exec("byte x = 5; x -= 10; return x;"));
-        // short
-        assertEquals((short) 15, exec("short x = 5; x -= -10; return x;"));
-        assertEquals((short) -5, exec("short x = 5; x -= 10; return x;"));
-        // char
-        assertEquals((char) 15, exec("char x = 5; x -= -10; return x;"));
-        assertEquals((char) 5, exec("char x = 10; x -= 5; return x;"));
-        // int
-        assertEquals(15, exec("int x = 5; x -= -10; return x;"));
-        assertEquals(-5, exec("int x = 5; x -= 10; return x;"));
-        // long
-        assertEquals(15L, exec("long x = 5; x -= -10; return x;"));
-        assertEquals(-5L, exec("long x = 5; x -= 10; return x;"));
-        // float
-        assertEquals(15F, exec("float x = 5f; x -= -10; return x;"));
-        assertEquals(-5F, exec("float x = 5f; x -= 10; return x;"));
-        // double
-        assertEquals(15D, exec("double x = 5.0; x -= -10; return x;"));
-        assertEquals(-5D, exec("double x = 5.0; x -= 10; return x;"));
-    }
-
-    public void testMultiplication() {
-        // byte
-        assertEquals((byte) 15, exec("byte x = 5; x *= 3; return x;"));
-        assertEquals((byte) -5, exec("byte x = 5; x *= -1; return x;"));
-        // short
-        assertEquals((short) 15, exec("short x = 5; x *= 3; return x;"));
-        assertEquals((short) -5, exec("short x = 5; x *= -1; return x;"));
-        // char
-        assertEquals((char) 15, exec("char x = 5; x *= 3; return x;"));
-        // int
-        assertEquals(15, exec("int x = 5; x *= 3; return x;"));
-        assertEquals(-5, exec("int x = 5; x *= -1; return x;"));
-        // long
-        assertEquals(15L, exec("long x = 5; x *= 3; return x;"));
-        assertEquals(-5L, exec("long x = 5; x *= -1; return x;"));
-        // float
-        assertEquals(15F, exec("float x = 5f; x *= 3; return x;"));
-        assertEquals(-5F, exec("float x = 5f; x *= -1; return x;"));
-        // double
-        assertEquals(15D, exec("double x = 5.0; x *= 3; return x;"));
-        assertEquals(-5D, exec("double x = 5.0; x *= -1; return x;"));
-    }
-
-    public void testDivision() {
-        // byte
-        assertEquals((byte) 15, exec("byte x = 45; x /= 3; return x;"));
-        assertEquals((byte) -5, exec("byte x = 5; x /= -1; return x;"));
-        // short
-        assertEquals((short) 15, exec("short x = 45; x /= 3; return x;"));
-        assertEquals((short) -5, exec("short x = 5; x /= -1; return x;"));
-        // char
-        assertEquals((char) 15, exec("char x = 45; x /= 3; return x;"));
-        // int
-        assertEquals(15, exec("int x = 45; x /= 3; return x;"));
-        assertEquals(-5, exec("int x = 5; x /= -1; return x;"));
-        // long
-        assertEquals(15L, exec("long x = 45; x /= 3; return x;"));
-        assertEquals(-5L, exec("long x = 5; x /= -1; return x;"));
-        // float
-        assertEquals(15F, exec("float x = 45f; x /= 3; return x;"));
-        assertEquals(-5F, exec("float x = 5f; x /= -1; return x;"));
-        // double
-        assertEquals(15D, exec("double x = 45.0; x /= 3; return x;"));
-        assertEquals(-5D, exec("double x = 5.0; x /= -1; return x;"));
-    }
-
-    public void testDivisionByZero() {
-        // byte
-        expectScriptThrows(ArithmeticException.class, () -> {
-            exec("byte x = 1; x /= 0; return x;");
-        });
-
-        // short
-        expectScriptThrows(ArithmeticException.class, () -> {
-            exec("short x = 1; x /= 0; return x;");
-        });
-
-        // char
-        expectScriptThrows(ArithmeticException.class, () -> {
-            exec("char x = 1; x /= 0; return x;");
-        });
-
-        // int
-        expectScriptThrows(ArithmeticException.class, () -> {
-            exec("int x = 1; x /= 0; return x;");
-        });
-
-        // long
-        expectScriptThrows(ArithmeticException.class, () -> {
-            exec("long x = 1; x /= 0; return x;");
-        });
-    }
-
-    public void testRemainder() {
-        // byte
-        assertEquals((byte) 3, exec("byte x = 15; x %= 4; return x;"));
-        assertEquals((byte) -3, exec("byte x = (byte) -15; x %= 4; return x;"));
-        // short
-        assertEquals((short) 3, exec("short x = 15; x %= 4; return x;"));
-        assertEquals((short) -3, exec("short x = (short) -15; x %= 4; return x;"));
-        // char
-        assertEquals((char) 3, exec("char x = (char) 15; x %= 4; return x;"));
-        // int
-        assertEquals(3, exec("int x = 15; x %= 4; return x;"));
-        assertEquals(-3, exec("int x = -15; x %= 4; return x;"));
-        // long
-        assertEquals(3L, exec("long x = 15L; x %= 4; return x;"));
-        assertEquals(-3L, exec("long x = -15L; x %= 4; return x;"));
-        // float
-        assertEquals(3F, exec("float x = 15F; x %= 4; return x;"));
-        assertEquals(-3F, exec("float x = -15F; x %= 4; return x;"));
-        // double
-        assertEquals(3D, exec("double x = 15.0; x %= 4; return x;"));
-        assertEquals(-3D, exec("double x = -15.0; x %= 4; return x;"));
-    }
-
-    public void testLeftShift() {
-        // byte
-        assertEquals((byte) 60, exec("byte x = 15; x <<= 2; return x;"));
-        assertEquals((byte) -60, exec("byte x = (byte) -15; x <<= 2; return x;"));
-        // short
-        assertEquals((short) 60, exec("short x = 15; x <<= 2; return x;"));
-        assertEquals((short) -60, exec("short x = (short) -15; x <<= 2; return x;"));
-        // char
-        assertEquals((char) 60, exec("char x = (char) 15; x <<= 2; return x;"));
-        // int
-        assertEquals(60, exec("int x = 15; x <<= 2; return x;"));
-        assertEquals(-60, exec("int x = -15; x <<= 2; return x;"));
-        // long
-        assertEquals(60L, exec("long x = 15L; x <<= 2; return x;"));
-        assertEquals(-60L, exec("long x = -15L; x <<= 2; return x;"));
-    }
-
-    public void testRightShift() {
-        // byte
-        assertEquals((byte) 15, exec("byte x = 60; x >>= 2; return x;"));
-        assertEquals((byte) -15, exec("byte x = (byte) -60; x >>= 2; return x;"));
-        // short
-        assertEquals((short) 15, exec("short x = 60; x >>= 2; return x;"));
-        assertEquals((short) -15, exec("short x = (short) -60; x >>= 2; return x;"));
-        // char
-        assertEquals((char) 15, exec("char x = (char) 60; x >>= 2; return x;"));
-        // int
-        assertEquals(15, exec("int x = 60; x >>= 2; return x;"));
-        assertEquals(-15, exec("int x = -60; x >>= 2; return x;"));
-        // long
-        assertEquals(15L, exec("long x = 60L; x >>= 2; return x;"));
-        assertEquals(-15L, exec("long x = -60L; x >>= 2; return x;"));
-    }
-
-    public void testUnsignedRightShift() {
-        // byte
-        assertEquals((byte) 15, exec("byte x = 60; x >>>= 2; return x;"));
-        assertEquals((byte) -15, exec("byte x = (byte) -60; x >>>= 2; return x;"));
-        // short
-        assertEquals((short) 15, exec("short x = 60; x >>>= 2; return x;"));
-        assertEquals((short) -15, exec("short x = (short) -60; x >>>= 2; return x;"));
-        // char
-        assertEquals((char) 15, exec("char x = (char) 60; x >>>= 2; return x;"));
-        // int
-        assertEquals(15, exec("int x = 60; x >>>= 2; return x;"));
-        assertEquals(-60 >>> 2, exec("int x = -60; x >>>= 2; return x;"));
-        // long
-        assertEquals(15L, exec("long x = 60L; x >>>= 2; return x;"));
-        assertEquals(-60L >>> 2, exec("long x = -60L; x >>>= 2; return x;"));
-    }
-
-    public void testAnd() {
-        // boolean
-        assertEquals(true, exec("boolean x = true; x &= true; return x;"));
-        assertEquals(false, exec("boolean x = true; x &= false; return x;"));
-        assertEquals(false, exec("boolean x = false; x &= true; return x;"));
-        assertEquals(false, exec("boolean x = false; x &= false; return x;"));
-        assertEquals(true, exec("def x = true; x &= true; return x;"));
-        assertEquals(false, exec("def x = true; x &= false; return x;"));
-        assertEquals(false, exec("def x = false; x &= true; return x;"));
-        assertEquals(false, exec("def x = false; x &= false; return x;"));
-        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] &= true; return x[0];"));
-        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] &= false; return x[0];"));
-        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] &= true; return x[0];"));
-        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] &= false; return x[0];"));
-        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] &= true; return x[0];"));
-        assertEquals(false, exec("def[] x = new def[1]; x[0] = true; x[0] &= false; return x[0];"));
-        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] &= true; return x[0];"));
-        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] &= false; return x[0];"));
-
-        // byte
-        assertEquals((byte) (13 & 14), exec("byte x = 13; x &= 14; return x;"));
-        // short
-        assertEquals((short) (13 & 14), exec("short x = 13; x &= 14; return x;"));
-        // char
-        assertEquals((char) (13 & 14), exec("char x = 13; x &= 14; return x;"));
-        // int
-        assertEquals(13 & 14, exec("int x = 13; x &= 14; return x;"));
-        // long
-        assertEquals((long) (13 & 14), exec("long x = 13L; x &= 14; return x;"));
-    }
-
-    public void testOr() {
-        // boolean
-        assertEquals(true, exec("boolean x = true; x |= true; return x;"));
-        assertEquals(true, exec("boolean x = true; x |= false; return x;"));
-        assertEquals(true, exec("boolean x = false; x |= true; return x;"));
-        assertEquals(false, exec("boolean x = false; x |= false; return x;"));
-        assertEquals(true, exec("def x = true; x |= true; return x;"));
-        assertEquals(true, exec("def x = true; x |= false; return x;"));
-        assertEquals(true, exec("def x = false; x |= true; return x;"));
-        assertEquals(false, exec("def x = false; x |= false; return x;"));
-        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] |= true; return x[0];"));
-        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] |= false; return x[0];"));
-        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] |= true; return x[0];"));
-        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] |= false; return x[0];"));
-        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] |= true; return x[0];"));
-        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] |= false; return x[0];"));
-        assertEquals(true, exec("def[] x = new def[1]; x[0] = false; x[0] |= true; return x[0];"));
-        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] |= false; return x[0];"));
-
-        // byte
-        assertEquals((byte) (13 | 14), exec("byte x = 13; x |= 14; return x;"));
-        // short
-        assertEquals((short) (13 | 14), exec("short x = 13; x |= 14; return x;"));
-        // char
-        assertEquals((char) (13 | 14), exec("char x = 13; x |= 14; return x;"));
-        // int
-        assertEquals(13 | 14, exec("int x = 13; x |= 14; return x;"));
-        // long
-        assertEquals((long) (13 | 14), exec("long x = 13L; x |= 14; return x;"));
-    }
-
-    public void testXor() {
-        // boolean
-        assertEquals(false, exec("boolean x = true; x ^= true; return x;"));
-        assertEquals(true, exec("boolean x = true; x ^= false; return x;"));
-        assertEquals(true, exec("boolean x = false; x ^= true; return x;"));
-        assertEquals(false, exec("boolean x = false; x ^= false; return x;"));
-        assertEquals(false, exec("def x = true; x ^= true; return x;"));
-        assertEquals(true, exec("def x = true; x ^= false; return x;"));
-        assertEquals(true, exec("def x = false; x ^= true; return x;"));
-        assertEquals(false, exec("def x = false; x ^= false; return x;"));
-        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] ^= true; return x[0];"));
-        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] ^= false; return x[0];"));
-        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] ^= true; return x[0];"));
-        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] ^= false; return x[0];"));
-        assertEquals(false, exec("def[] x = new def[1]; x[0] = true; x[0] ^= true; return x[0];"));
-        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] ^= false; return x[0];"));
-        assertEquals(true, exec("def[] x = new def[1]; x[0] = false; x[0] ^= true; return x[0];"));
-        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] ^= false; return x[0];"));
-
-        // byte
-        assertEquals((byte) (13 ^ 14), exec("byte x = 13; x ^= 14; return x;"));
-        // short
-        assertEquals((short) (13 ^ 14), exec("short x = 13; x ^= 14; return x;"));
-        // char
-        assertEquals((char) (13 ^ 14), exec("char x = 13; x ^= 14; return x;"));
-        // int
-        assertEquals(13 ^ 14, exec("int x = 13; x ^= 14; return x;"));
-        // long
-        assertEquals((long) (13 ^ 14), exec("long x = 13L; x ^= 14; return x;"));
-    }
-}

+ 83 - 15
modules/lang-painless/src/test/java/org/elasticsearch/painless/DefBootstrapTests.java

@@ -40,11 +40,11 @@ public class DefBootstrapTests extends ESTestCase {
         assertDepthEquals(site, 0);
 
         // invoke with integer, needs lookup
-        assertEquals("5", handle.invoke(Integer.valueOf(5)));
+        assertEquals("5", (String)handle.invokeExact((Object)5));
         assertDepthEquals(site, 1);
 
         // invoked with integer again: should be cached
-        assertEquals("6", handle.invoke(Integer.valueOf(6)));
+        assertEquals("6", (String)handle.invokeExact((Object)6));
         assertDepthEquals(site, 1);
     }
     
@@ -56,15 +56,15 @@ public class DefBootstrapTests extends ESTestCase {
         MethodHandle handle = site.dynamicInvoker();
         assertDepthEquals(site, 0);
 
-        assertEquals("5", handle.invoke(Integer.valueOf(5)));
+        assertEquals("5", (String)handle.invokeExact((Object)5));
         assertDepthEquals(site, 1);
-        assertEquals("1.5", handle.invoke(Float.valueOf(1.5f)));
+        assertEquals("1.5", (String)handle.invokeExact((Object)1.5f));
         assertDepthEquals(site, 2);
 
         // both these should be cached
-        assertEquals("6", handle.invoke(Integer.valueOf(6)));
+        assertEquals("6", (String)handle.invokeExact((Object)6));
         assertDepthEquals(site, 2);
-        assertEquals("2.5", handle.invoke(Float.valueOf(2.5f)));
+        assertEquals("2.5", (String)handle.invokeExact((Object)2.5f));
         assertDepthEquals(site, 2);
     }
     
@@ -78,17 +78,17 @@ public class DefBootstrapTests extends ESTestCase {
         MethodHandle handle = site.dynamicInvoker();
         assertDepthEquals(site, 0);
 
-        assertEquals("5", handle.invoke(Integer.valueOf(5)));
+        assertEquals("5", (String)handle.invokeExact((Object)5));
         assertDepthEquals(site, 1);
-        assertEquals("1.5", handle.invoke(Float.valueOf(1.5f)));
+        assertEquals("1.5", (String)handle.invokeExact((Object)1.5f));
         assertDepthEquals(site, 2);
-        assertEquals("6", handle.invoke(Long.valueOf(6)));
+        assertEquals("6", (String)handle.invokeExact((Object)6L));
         assertDepthEquals(site, 3);
-        assertEquals("3.2", handle.invoke(Double.valueOf(3.2d)));
+        assertEquals("3.2", (String)handle.invokeExact((Object)3.2d));
         assertDepthEquals(site, 4);
-        assertEquals("foo", handle.invoke(new String("foo")));
+        assertEquals("foo", (String)handle.invokeExact((Object)"foo"));
         assertDepthEquals(site, 5);
-        assertEquals("c", handle.invoke(Character.valueOf('c')));
+        assertEquals("c", (String)handle.invokeExact((Object)'c'));
         assertDepthEquals(site, 5);
     }
     
@@ -100,9 +100,77 @@ public class DefBootstrapTests extends ESTestCase {
                                                                           DefBootstrap.METHOD_CALL, 0L);
         site.depth = DefBootstrap.PIC.MAX_DEPTH; // mark megamorphic
         MethodHandle handle = site.dynamicInvoker();
-        // arguments are cast to object here, or IDE compilers eat it :)
-        assertEquals(2, handle.invoke((Object) Arrays.asList("1", "2")));
-        assertEquals(1, handle.invoke((Object) Collections.singletonMap("a", "b")));
+        assertEquals(2, (int)handle.invokeExact((Object) Arrays.asList("1", "2")));
+        assertEquals(1, (int)handle.invokeExact((Object) Collections.singletonMap("a", "b")));
+    }
+    
+    // test operators with null guards
+
+    public void testNullGuardAdd() throws Throwable {
+        DefBootstrap.MIC site = (DefBootstrap.MIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), 
+                                                               "add", 
+                                                               MethodType.methodType(Object.class, Object.class, Object.class),
+                                                               DefBootstrap.BINARY_OPERATOR, DefBootstrap.OPERATOR_ALLOWS_NULL);
+        MethodHandle handle = site.dynamicInvoker();
+        assertEquals("nulltest", (Object)handle.invokeExact((Object)null, (Object)"test"));
+    }
+    
+    public void testNullGuardAddWhenCached() throws Throwable {
+        DefBootstrap.MIC site = (DefBootstrap.MIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), 
+                                                               "add", 
+                                                               MethodType.methodType(Object.class, Object.class, Object.class),
+                                                               DefBootstrap.BINARY_OPERATOR, DefBootstrap.OPERATOR_ALLOWS_NULL);
+        MethodHandle handle = site.dynamicInvoker();
+        assertEquals(2, (Object)handle.invokeExact((Object)1, (Object)1));
+        assertEquals("nulltest", (Object)handle.invoke((Object)null, (Object)"test"));
+    }
+    
+    public void testNullGuardEq() throws Throwable {
+        DefBootstrap.MIC site = (DefBootstrap.MIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), 
+                                                               "eq", 
+                                                               MethodType.methodType(boolean.class, Object.class, Object.class),
+                                                               DefBootstrap.BINARY_OPERATOR, DefBootstrap.OPERATOR_ALLOWS_NULL);
+        MethodHandle handle = site.dynamicInvoker();
+        assertFalse((boolean) handle.invokeExact((Object)null, (Object)"test"));
+        assertTrue((boolean) handle.invokeExact((Object)null, (Object)null));
+    }
+    
+    public void testNullGuardEqWhenCached() throws Throwable {
+        DefBootstrap.MIC site = (DefBootstrap.MIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), 
+                                                               "eq", 
+                                                               MethodType.methodType(boolean.class, Object.class, Object.class),
+                                                               DefBootstrap.BINARY_OPERATOR, DefBootstrap.OPERATOR_ALLOWS_NULL);
+        MethodHandle handle = site.dynamicInvoker();
+        assertTrue((boolean) handle.invokeExact((Object)1, (Object)1));
+        assertFalse((boolean) handle.invokeExact((Object)null, (Object)"test"));
+        assertTrue((boolean) handle.invokeExact((Object)null, (Object)null));
+    }
+    
+    // make sure these operators work without null guards too
+    // for example, nulls are only legal for + if the other parameter is a String,
+    // and can be disabled in some circumstances.
+    
+    public void testNoNullGuardAdd() throws Throwable {
+        DefBootstrap.MIC site = (DefBootstrap.MIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), 
+                                                               "add", 
+                                                               MethodType.methodType(Object.class, int.class, Object.class),
+                                                               DefBootstrap.BINARY_OPERATOR, 0);
+        MethodHandle handle = site.dynamicInvoker();
+        expectThrows(NullPointerException.class, () -> {
+            assertNotNull((Object)handle.invokeExact(5, (Object)null));
+        });
+    }
+    
+    public void testNoNullGuardAddWhenCached() throws Throwable {
+        DefBootstrap.MIC site = (DefBootstrap.MIC) DefBootstrap.bootstrap(MethodHandles.publicLookup(), 
+                                                               "add", 
+                                                               MethodType.methodType(Object.class, int.class, Object.class),
+                                                               DefBootstrap.BINARY_OPERATOR, 0);
+        MethodHandle handle = site.dynamicInvoker();
+        assertEquals(2, (Object)handle.invokeExact(1, (Object)1));
+        expectThrows(NullPointerException.class, () -> {
+            assertNotNull((Object)handle.invokeExact(5, (Object)null));
+        });
     }
     
     static void assertDepthEquals(CallSite site, int expected) {

+ 34 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/DefOptimizationTests.java

@@ -239,6 +239,25 @@ public class DefOptimizationTests extends ScriptTestCase {
                              "INVOKEDYNAMIC add(Ljava/lang/Object;Ljava/lang/Object;)D");
     }
     
+    // horrible, sorry
+    public void testAddOptNullGuards() {
+        // needs null guard
+        assertBytecodeHasPattern("def x = 1; def y = 2; return x + y", 
+                "(?s).*INVOKEDYNAMIC add.*arguments:\\s+" + DefBootstrap.BINARY_OPERATOR 
+                                                + ",\\s+" + DefBootstrap.OPERATOR_ALLOWS_NULL + ".*");
+        // still needs null guard, NPE is the wrong thing!
+        assertBytecodeHasPattern("def x = 1; def y = 2; double z = x + y", 
+                "(?s).*INVOKEDYNAMIC add.*arguments:\\s+" + DefBootstrap.BINARY_OPERATOR 
+                                                + ",\\s+" + DefBootstrap.OPERATOR_ALLOWS_NULL + ".*");
+        // a primitive argument is present: no null guard needed
+        assertBytecodeHasPattern("def x = 1; int y = 2; return x + y", 
+                "(?s).*INVOKEDYNAMIC add.*arguments:\\s+" + DefBootstrap.BINARY_OPERATOR 
+                                                + ",\\s+" + 0 + ".*");
+        assertBytecodeHasPattern("int x = 1; def y = 2; return x + y", 
+                "(?s).*INVOKEDYNAMIC add.*arguments:\\s+" + DefBootstrap.BINARY_OPERATOR 
+                                                + ",\\s+" + 0 + ".*");
+    }
+    
     public void testSubOptLHS() {
         assertBytecodeExists("int x = 1; def y = 2; return x - y", 
                              "INVOKEDYNAMIC sub(ILjava/lang/Object;)Ljava/lang/Object;");
@@ -343,7 +362,22 @@ public class DefOptimizationTests extends ScriptTestCase {
         assertBytecodeExists("def x = 1; def y = 2; double d = x ^ y", 
                              "INVOKEDYNAMIC xor(Ljava/lang/Object;Ljava/lang/Object;)D");
     }
+
+    public void testBooleanXorOptLHS() {
+        assertBytecodeExists("boolean x = true; def y = true; return x ^ y", 
+                "INVOKEDYNAMIC xor(ZLjava/lang/Object;)Ljava/lang/Object;");
+    }
+
+    public void testBooleanXorOptRHS() {
+        assertBytecodeExists("def x = true; boolean y = true; return x ^ y", 
+                "INVOKEDYNAMIC xor(Ljava/lang/Object;Z)Ljava/lang/Object;");
+    }
     
+    public void testBooleanXorOptRet() {
+        assertBytecodeExists("def x = true; def y = true; boolean v = x ^ y", 
+                "INVOKEDYNAMIC xor(Ljava/lang/Object;Ljava/lang/Object;)Z");
+    }
+
     public void testLtOptLHS() {
         assertBytecodeExists("int x = 1; def y = 2; return x < y", 
                              "INVOKEDYNAMIC lt(ILjava/lang/Object;)Z");

+ 78 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/DivisionTests.java

@@ -335,4 +335,82 @@ public class DivisionTests extends ScriptTestCase {
         assertEquals(1F, exec("def x = (float)2; float y = (float)2; return x / y"));
         assertEquals(1D, exec("def x = (double)2; double y = (double)2; return x / y"));
     }
+    
+    public void testCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 45; x /= 3; return x;"));
+        assertEquals((byte) -5, exec("byte x = 5; x /= -1; return x;"));
+        // short
+        assertEquals((short) 15, exec("short x = 45; x /= 3; return x;"));
+        assertEquals((short) -5, exec("short x = 5; x /= -1; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = 45; x /= 3; return x;"));
+        // int
+        assertEquals(15, exec("int x = 45; x /= 3; return x;"));
+        assertEquals(-5, exec("int x = 5; x /= -1; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 45; x /= 3; return x;"));
+        assertEquals(-5L, exec("long x = 5; x /= -1; return x;"));
+        // float
+        assertEquals(15F, exec("float x = 45f; x /= 3; return x;"));
+        assertEquals(-5F, exec("float x = 5f; x /= -1; return x;"));
+        // double
+        assertEquals(15D, exec("double x = 45.0; x /= 3; return x;"));
+        assertEquals(-5D, exec("double x = 5.0; x /= -1; return x;"));
+    }
+    
+    public void testDefCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("def x = (byte)45; x /= 3; return x;"));
+        assertEquals((byte) -5, exec("def x = (byte)5; x /= -1; return x;"));
+        // short
+        assertEquals((short) 15, exec("def x = (short)45; x /= 3; return x;"));
+        assertEquals((short) -5, exec("def x = (short)5; x /= -1; return x;"));
+        // char
+        assertEquals((char) 15, exec("def x = (char)45; x /= 3; return x;"));
+        // int
+        assertEquals(15, exec("def x = 45; x /= 3; return x;"));
+        assertEquals(-5, exec("def x = 5; x /= -1; return x;"));
+        // long
+        assertEquals(15L, exec("def x = 45L; x /= 3; return x;"));
+        assertEquals(-5L, exec("def x = 5L; x /= -1; return x;"));
+        // float
+        assertEquals(15F, exec("def x = 45f; x /= 3; return x;"));
+        assertEquals(-5F, exec("def x = 5f; x /= -1; return x;"));
+        // double
+        assertEquals(15D, exec("def x = 45.0; x /= 3; return x;"));
+        assertEquals(-5D, exec("def x = 5.0; x /= -1; return x;"));
+    }
+    
+    public void testCompoundAssignmentByZero() {
+        // byte
+        expectScriptThrows(ArithmeticException.class, () -> {
+            exec("byte x = 1; x /= 0; return x;");
+        });
+
+        // short
+        expectScriptThrows(ArithmeticException.class, () -> {
+            exec("short x = 1; x /= 0; return x;");
+        });
+
+        // char
+        expectScriptThrows(ArithmeticException.class, () -> {
+            exec("char x = 1; x /= 0; return x;");
+        });
+
+        // int
+        expectScriptThrows(ArithmeticException.class, () -> {
+            exec("int x = 1; x /= 0; return x;");
+        });
+
+        // long
+        expectScriptThrows(ArithmeticException.class, () -> {
+            exec("long x = 1; x /= 0; return x;");
+        });
+        
+        // def
+        expectScriptThrows(ArithmeticException.class, () -> {
+            exec("def x = 1; x /= 0; return x;");
+        });
+    }
 }

+ 31 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/IncrementTests.java

@@ -76,4 +76,35 @@ public class IncrementTests extends ScriptTestCase {
         assertEquals(1D, exec("double x = 0.0; return ++x;"));
         assertEquals(-1D, exec("double x = 0.0; return --x;"));
     }
+    
+    /** incrementing def values */
+    public void testIncrementDef() {
+        assertEquals((byte)0, exec("def x = (byte)0; return x++;"));
+        assertEquals((byte)0, exec("def x = (byte)0; return x--;"));
+        assertEquals((byte)1, exec("def x = (byte)0; return ++x;"));
+        assertEquals((byte)-1, exec("def x = (byte)0; return --x;"));
+        assertEquals((char)0, exec("def x = (char)0; return x++;"));
+        assertEquals((char)1, exec("def x = (char)1; return x--;"));
+        assertEquals((char)1, exec("def x = (char)0; return ++x;"));
+        assertEquals((short)0, exec("def x = (short)0; return x++;"));
+        assertEquals((short)0, exec("def x = (short)0; return x--;"));
+        assertEquals((short)1, exec("def x = (short)0; return ++x;"));
+        assertEquals((short)-1, exec("def x = (short)0; return --x;"));
+        assertEquals(0, exec("def x = 0; return x++;"));
+        assertEquals(0, exec("def x = 0; return x--;"));
+        assertEquals(1, exec("def x = 0; return ++x;"));
+        assertEquals(-1, exec("def x = 0; return --x;"));
+        assertEquals(0L, exec("def x = 0L; return x++;"));
+        assertEquals(0L, exec("def x = 0L; return x--;"));
+        assertEquals(1L, exec("def x = 0L; return ++x;"));
+        assertEquals(-1L, exec("def x = 0L; return --x;"));
+        assertEquals(0F, exec("def x = 0F; return x++;"));
+        assertEquals(0F, exec("def x = 0F; return x--;"));
+        assertEquals(1F, exec("def x = 0F; return ++x;"));
+        assertEquals(-1F, exec("def x = 0F; return --x;"));
+        assertEquals(0D, exec("def x = 0.0; return x++;"));
+        assertEquals(0D, exec("def x = 0.0; return x--;"));
+        assertEquals(1D, exec("def x = 0.0; return ++x;"));
+        assertEquals(-1D, exec("def x = 0.0; return --x;"));
+    }
 }

+ 46 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/MultiplicationTests.java

@@ -325,4 +325,50 @@ public class MultiplicationTests extends ScriptTestCase {
         assertEquals(4F, exec("def x = (float)2; float y = (float)2; return x * y"));
         assertEquals(4D, exec("def x = (double)2; double y = (double)2; return x * y"));
     }
+    
+    public void testCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 5; x *= 3; return x;"));
+        assertEquals((byte) -5, exec("byte x = 5; x *= -1; return x;"));
+        // short
+        assertEquals((short) 15, exec("short x = 5; x *= 3; return x;"));
+        assertEquals((short) -5, exec("short x = 5; x *= -1; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = 5; x *= 3; return x;"));
+        // int
+        assertEquals(15, exec("int x = 5; x *= 3; return x;"));
+        assertEquals(-5, exec("int x = 5; x *= -1; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 5; x *= 3; return x;"));
+        assertEquals(-5L, exec("long x = 5; x *= -1; return x;"));
+        // float
+        assertEquals(15F, exec("float x = 5f; x *= 3; return x;"));
+        assertEquals(-5F, exec("float x = 5f; x *= -1; return x;"));
+        // double
+        assertEquals(15D, exec("double x = 5.0; x *= 3; return x;"));
+        assertEquals(-5D, exec("double x = 5.0; x *= -1; return x;"));
+    }
+    
+    public void testDefCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("def x = (byte)5; x *= 3; return x;"));
+        assertEquals((byte) -5, exec("def x = (byte)5; x *= -1; return x;"));
+        // short
+        assertEquals((short) 15, exec("def x = (short)5; x *= 3; return x;"));
+        assertEquals((short) -5, exec("def x = (short)5; x *= -1; return x;"));
+        // char
+        assertEquals((char) 15, exec("def x = (char)5; x *= 3; return x;"));
+        // int
+        assertEquals(15, exec("def x = 5; x *= 3; return x;"));
+        assertEquals(-5, exec("def x = 5; x *= -1; return x;"));
+        // long
+        assertEquals(15L, exec("def x = 5L; x *= 3; return x;"));
+        assertEquals(-5L, exec("def x = 5L; x *= -1; return x;"));
+        // float
+        assertEquals(15F, exec("def x = 5f; x *= 3; return x;"));
+        assertEquals(-5F, exec("def x = 5f; x *= -1; return x;"));
+        // double
+        assertEquals(15D, exec("def x = 5.0; x *= 3; return x;"));
+        assertEquals(-5D, exec("def x = 5.0; x *= -1; return x;"));
+    }
 }

+ 76 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/OrTests.java

@@ -208,4 +208,80 @@ public class OrTests extends ScriptTestCase {
         assertEquals(true,  exec("def x = false; boolean y = true; return x | y"));
         assertEquals(false, exec("def x = false; boolean y = false; return x | y"));
     }
+    
+    public void testCompoundAssignment() {
+        // boolean
+        assertEquals(true, exec("boolean x = true; x |= true; return x;"));
+        assertEquals(true, exec("boolean x = true; x |= false; return x;"));
+        assertEquals(true, exec("boolean x = false; x |= true; return x;"));
+        assertEquals(false, exec("boolean x = false; x |= false; return x;"));
+        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] |= true; return x[0];"));
+        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] |= false; return x[0];"));
+        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] |= true; return x[0];"));
+        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] |= false; return x[0];"));
+
+        // byte
+        assertEquals((byte) (13 | 14), exec("byte x = 13; x |= 14; return x;"));
+        // short
+        assertEquals((short) (13 | 14), exec("short x = 13; x |= 14; return x;"));
+        // char
+        assertEquals((char) (13 | 14), exec("char x = 13; x |= 14; return x;"));
+        // int
+        assertEquals(13 | 14, exec("int x = 13; x |= 14; return x;"));
+        // long
+        assertEquals((long) (13 | 14), exec("long x = 13L; x |= 14; return x;"));
+    }
+    
+    public void testBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("float x = 4; int y = 1; x |= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("double x = 4; int y = 1; x |= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; float y = 1; x |= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; double y = 1; x |= y");
+        });
+    }
+    
+    public void testDefCompoundAssignment() {
+        // boolean
+        assertEquals(true, exec("def x = true; x |= true; return x;"));
+        assertEquals(true, exec("def x = true; x |= false; return x;"));
+        assertEquals(true, exec("def x = false; x |= true; return x;"));
+        assertEquals(false, exec("def x = false; x |= false; return x;"));
+        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] |= true; return x[0];"));
+        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] |= false; return x[0];"));
+        assertEquals(true, exec("def[] x = new def[1]; x[0] = false; x[0] |= true; return x[0];"));
+        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] |= false; return x[0];"));
+
+        // byte
+        assertEquals((byte) (13 | 14), exec("def x = (byte)13; x |= 14; return x;"));
+        // short
+        assertEquals((short) (13 | 14), exec("def x = (short)13; x |= 14; return x;"));
+        // char
+        assertEquals((char) (13 | 14), exec("def x = (char)13; x |= 14; return x;"));
+        // int
+        assertEquals(13 | 14, exec("def x = 13; x |= 14; return x;"));
+        // long
+        assertEquals((long) (13 | 14), exec("def x = 13L; x |= 14; return x;"));
+    }
+    
+    public void testDefBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4F; int y = 1; x |= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4D; int y = 1; x |= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4; float y = 1; x |= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4; double y = 1; x |= y");
+        });
+    }
 }

+ 46 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/RemainderTests.java

@@ -335,4 +335,50 @@ public class RemainderTests extends ScriptTestCase {
         assertEquals(0F, exec("def x = (float)2; float y = (float)2; return x % y"));
         assertEquals(0D, exec("def x = (double)2; double y = (double)2; return x % y"));
     }
+    
+    public void testCompoundAssignment() {
+        // byte
+        assertEquals((byte) 3, exec("byte x = 15; x %= 4; return x;"));
+        assertEquals((byte) -3, exec("byte x = (byte) -15; x %= 4; return x;"));
+        // short
+        assertEquals((short) 3, exec("short x = 15; x %= 4; return x;"));
+        assertEquals((short) -3, exec("short x = (short) -15; x %= 4; return x;"));
+        // char
+        assertEquals((char) 3, exec("char x = (char) 15; x %= 4; return x;"));
+        // int
+        assertEquals(3, exec("int x = 15; x %= 4; return x;"));
+        assertEquals(-3, exec("int x = -15; x %= 4; return x;"));
+        // long
+        assertEquals(3L, exec("long x = 15L; x %= 4; return x;"));
+        assertEquals(-3L, exec("long x = -15L; x %= 4; return x;"));
+        // float
+        assertEquals(3F, exec("float x = 15F; x %= 4; return x;"));
+        assertEquals(-3F, exec("float x = -15F; x %= 4; return x;"));
+        // double
+        assertEquals(3D, exec("double x = 15.0; x %= 4; return x;"));
+        assertEquals(-3D, exec("double x = -15.0; x %= 4; return x;"));
+    }
+    
+    public void testDefCompoundAssignment() {
+        // byte
+        assertEquals((byte) 3, exec("def x = (byte)15; x %= 4; return x;"));
+        assertEquals((byte) -3, exec("def x = (byte) -15; x %= 4; return x;"));
+        // short
+        assertEquals((short) 3, exec("def x = (short)15; x %= 4; return x;"));
+        assertEquals((short) -3, exec("def x = (short) -15; x %= 4; return x;"));
+        // char
+        assertEquals((char) 3, exec("def x = (char) 15; x %= 4; return x;"));
+        // int
+        assertEquals(3, exec("def x = 15; x %= 4; return x;"));
+        assertEquals(-3, exec("def x = -15; x %= 4; return x;"));
+        // long
+        assertEquals(3L, exec("def x = 15L; x %= 4; return x;"));
+        assertEquals(-3L, exec("def x = -15L; x %= 4; return x;"));
+        // float
+        assertEquals(3F, exec("def x = 15F; x %= 4; return x;"));
+        assertEquals(-3F, exec("def x = -15F; x %= 4; return x;"));
+        // double
+        assertEquals(3D, exec("def x = 15.0; x %= 4; return x;"));
+        assertEquals(-3D, exec("def x = -15.0; x %= 4; return x;"));
+    }
 }

+ 9 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java

@@ -79,6 +79,15 @@ public abstract class ScriptTestCase extends ESTestCase {
         assertTrue("bytecode not found, got: \n" + asm , asm.contains(bytecode));
     }
     
+    /**
+     * Uses the {@link Debugger} to get the bytecode output for a script and compare
+     * it against an expected bytecode pattern as a regular expression (please try to avoid!)
+     */
+    public void assertBytecodeHasPattern(String script, String pattern) {
+        final String asm = Debugger.toString(script);
+        assertTrue("bytecode not found, got: \n" + asm , asm.matches(pattern));
+    }
+    
     /** Checks a specific exception class is thrown (boxed inside ScriptException) and returns it. */
     public static <T extends Throwable> T expectScriptThrows(Class<T> expectedType, ThrowingRunnable runnable) {
         try {

+ 104 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/ShiftTests.java

@@ -544,4 +544,108 @@ public class ShiftTests extends ScriptTestCase {
         });
     }
 
+    public void testLshCompoundAssignment() {
+        // byte
+        assertEquals((byte) 60, exec("byte x = 15; x <<= 2; return x;"));
+        assertEquals((byte) -60, exec("byte x = (byte) -15; x <<= 2; return x;"));
+        // short
+        assertEquals((short) 60, exec("short x = 15; x <<= 2; return x;"));
+        assertEquals((short) -60, exec("short x = (short) -15; x <<= 2; return x;"));
+        // char
+        assertEquals((char) 60, exec("char x = (char) 15; x <<= 2; return x;"));
+        // int
+        assertEquals(60, exec("int x = 15; x <<= 2; return x;"));
+        assertEquals(-60, exec("int x = -15; x <<= 2; return x;"));
+        // long
+        assertEquals(60L, exec("long x = 15L; x <<= 2; return x;"));
+        assertEquals(-60L, exec("long x = -15L; x <<= 2; return x;"));
+        // long shift distance
+        assertEquals(60, exec("int x = 15; x <<= 2L; return x;"));
+        assertEquals(-60, exec("int x = -15; x <<= 2L; return x;"));
+    }
+
+    public void testRshCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 60; x >>= 2; return x;"));
+        assertEquals((byte) -15, exec("byte x = (byte) -60; x >>= 2; return x;"));
+        // short
+        assertEquals((short) 15, exec("short x = 60; x >>= 2; return x;"));
+        assertEquals((short) -15, exec("short x = (short) -60; x >>= 2; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = (char) 60; x >>= 2; return x;"));
+        // int
+        assertEquals(15, exec("int x = 60; x >>= 2; return x;"));
+        assertEquals(-15, exec("int x = -60; x >>= 2; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 60L; x >>= 2; return x;"));
+        assertEquals(-15L, exec("long x = -60L; x >>= 2; return x;"));
+        // long shift distance
+        assertEquals(15, exec("int x = 60; x >>= 2L; return x;"));
+        assertEquals(-15, exec("int x = -60; x >>= 2L; return x;"));
+    }
+
+    public void testUshCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 60; x >>>= 2; return x;"));
+        assertEquals((byte) -15, exec("byte x = (byte) -60; x >>>= 2; return x;"));
+        // short
+        assertEquals((short) 15, exec("short x = 60; x >>>= 2; return x;"));
+        assertEquals((short) -15, exec("short x = (short) -60; x >>>= 2; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = (char) 60; x >>>= 2; return x;"));
+        // int
+        assertEquals(15, exec("int x = 60; x >>>= 2; return x;"));
+        assertEquals(-60 >>> 2, exec("int x = -60; x >>>= 2; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 60L; x >>>= 2; return x;"));
+        assertEquals(-60L >>> 2, exec("long x = -60L; x >>>= 2; return x;"));
+        // long shift distance
+        assertEquals(15, exec("int x = 60; x >>>= 2L; return x;"));
+        assertEquals(-60 >>> 2, exec("int x = -60; x >>>= 2L; return x;"));
+    }
+    
+    public void testBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("long x = 1L; float y = 2; x <<= y;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("int x = 1; double y = 2L; x <<= y;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("float x = 1F; int y = 2; x <<= y;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("double x = 1D; int y = 2L; x <<= y;");
+        });
+    }
+
+    public void testBogusCompoundAssignmentConst() {
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("int x = 1L; x <<= 2F;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("int x = 1L; x <<= 2.0;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("float x = 1F; x <<= 2;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("double x = 1D; x <<= 2L;");
+        });
+    }
+    
+    public void testBogusCompoundAssignmentDef() {
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("def x = 1L; float y = 2; x <<= y;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("def x = 1; double y = 2L; x <<= y;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("float x = 1F; def y = 2; x <<= y;");
+        });
+        expectScriptThrows(ClassCastException.class, ()-> {
+            exec("double x = 1D; def y = 2L; x <<= y;");
+        });
+    }
 }

+ 16 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/StringTests.java

@@ -205,4 +205,20 @@ public class StringTests extends ScriptTestCase {
             exec("def x = null; def y = null; return x + y");
         });
     }
+    
+    public void testDefCompoundAssignment() {
+        assertEquals("a" + (byte)2, exec("def x = 'a'; x += (byte)2; return x"));
+        assertEquals("a" + (short)2, exec("def x = 'a'; x  += (short)2; return x"));
+        assertEquals("a" + (char)2, exec("def x = 'a'; x += (char)2; return x"));
+        assertEquals("a" + 2, exec("def x = 'a'; x += (int)2; return x"));
+        assertEquals("a" + 2L, exec("def x = 'a'; x += (long)2; return x"));
+        assertEquals("a" + 2F, exec("def x = 'a'; x += (float)2; return x"));
+        assertEquals("a" + 2D, exec("def x = 'a'; x += (double)2; return x"));
+        assertEquals("ab", exec("def x = 'a'; def y = 'b'; x += y; return x"));
+        assertEquals("anull", exec("def x = 'a'; x += null; return x"));
+        assertEquals("nullb", exec("def x = null; x += 'b'; return x"));
+        expectScriptThrows(NullPointerException.class, () -> {
+            exec("def x = null; def y = null; x += y");
+        });
+    }
 }

+ 48 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/SubtractionTests.java

@@ -355,4 +355,52 @@ public class SubtractionTests extends ScriptTestCase {
         assertEquals(0D, exec("def x = (float)1; double y = (double)1; return x - y"));
         assertEquals(0D, exec("def x = (double)1; double y = (double)1; return x - y"));
     }
+    
+    public void testCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("byte x = 5; x -= -10; return x;"));
+        assertEquals((byte) -5, exec("byte x = 5; x -= 10; return x;"));
+        // short
+        assertEquals((short) 15, exec("short x = 5; x -= -10; return x;"));
+        assertEquals((short) -5, exec("short x = 5; x -= 10; return x;"));
+        // char
+        assertEquals((char) 15, exec("char x = 5; x -= -10; return x;"));
+        assertEquals((char) 5, exec("char x = 10; x -= 5; return x;"));
+        // int
+        assertEquals(15, exec("int x = 5; x -= -10; return x;"));
+        assertEquals(-5, exec("int x = 5; x -= 10; return x;"));
+        // long
+        assertEquals(15L, exec("long x = 5; x -= -10; return x;"));
+        assertEquals(-5L, exec("long x = 5; x -= 10; return x;"));
+        // float
+        assertEquals(15F, exec("float x = 5f; x -= -10; return x;"));
+        assertEquals(-5F, exec("float x = 5f; x -= 10; return x;"));
+        // double
+        assertEquals(15D, exec("double x = 5.0; x -= -10; return x;"));
+        assertEquals(-5D, exec("double x = 5.0; x -= 10; return x;"));
+    }
+    
+    public void testDefCompoundAssignment() {
+        // byte
+        assertEquals((byte) 15, exec("def x = (byte)5; x -= -10; return x;"));
+        assertEquals((byte) -5, exec("def x = (byte)5; x -= 10; return x;"));
+        // short
+        assertEquals((short) 15, exec("def x = (short)5; x -= -10; return x;"));
+        assertEquals((short) -5, exec("def x = (short)5; x -= 10; return x;"));
+        // char
+        assertEquals((char) 15, exec("def x = (char)5; x -= -10; return x;"));
+        assertEquals((char) 5, exec("def x = (char)10; x -= 5; return x;"));
+        // int
+        assertEquals(15, exec("def x = 5; x -= -10; return x;"));
+        assertEquals(-5, exec("def x = 5; x -= 10; return x;"));
+        // long
+        assertEquals(15L, exec("def x = 5L; x -= -10; return x;"));
+        assertEquals(-5L, exec("def x = 5L; x -= 10; return x;"));
+        // float
+        assertEquals(15F, exec("def x = 5f; x -= -10; return x;"));
+        assertEquals(-5F, exec("def x = 5f; x -= 10; return x;"));
+        // double
+        assertEquals(15D, exec("def x = 5.0; x -= -10; return x;"));
+        assertEquals(-5D, exec("def x = 5.0; x -= 10; return x;"));
+    }
 }

+ 76 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/XorTests.java

@@ -222,4 +222,80 @@ public class XorTests extends ScriptTestCase {
         assertEquals(true,  exec("def x = false; boolean y = true; return x ^ y"));
         assertEquals(false, exec("def x = false; boolean y = false; return x ^ y"));
     }
+    
+    public void testCompoundAssignment() {
+        // boolean
+        assertEquals(false, exec("boolean x = true; x ^= true; return x;"));
+        assertEquals(true, exec("boolean x = true; x ^= false; return x;"));
+        assertEquals(true, exec("boolean x = false; x ^= true; return x;"));
+        assertEquals(false, exec("boolean x = false; x ^= false; return x;"));
+        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] ^= true; return x[0];"));
+        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = true; x[0] ^= false; return x[0];"));
+        assertEquals(true, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] ^= true; return x[0];"));
+        assertEquals(false, exec("boolean[] x = new boolean[1]; x[0] = false; x[0] ^= false; return x[0];"));
+
+        // byte
+        assertEquals((byte) (13 ^ 14), exec("byte x = 13; x ^= 14; return x;"));
+        // short
+        assertEquals((short) (13 ^ 14), exec("short x = 13; x ^= 14; return x;"));
+        // char
+        assertEquals((char) (13 ^ 14), exec("char x = 13; x ^= 14; return x;"));
+        // int
+        assertEquals(13 ^ 14, exec("int x = 13; x ^= 14; return x;"));
+        // long
+        assertEquals((long) (13 ^ 14), exec("long x = 13L; x ^= 14; return x;"));
+    }
+    
+    public void testBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("float x = 4; int y = 1; x ^= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("double x = 4; int y = 1; x ^= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; float y = 1; x ^= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; double y = 1; x ^= y");
+        });
+    }
+    
+    public void testCompoundAssignmentDef() {
+        // boolean
+        assertEquals(false, exec("def x = true; x ^= true; return x;"));
+        assertEquals(true, exec("def x = true; x ^= false; return x;"));
+        assertEquals(true, exec("def x = false; x ^= true; return x;"));
+        assertEquals(false, exec("def x = false; x ^= false; return x;"));
+        assertEquals(false, exec("def[] x = new def[1]; x[0] = true; x[0] ^= true; return x[0];"));
+        assertEquals(true, exec("def[] x = new def[1]; x[0] = true; x[0] ^= false; return x[0];"));
+        assertEquals(true, exec("def[] x = new def[1]; x[0] = false; x[0] ^= true; return x[0];"));
+        assertEquals(false, exec("def[] x = new def[1]; x[0] = false; x[0] ^= false; return x[0];"));
+
+        // byte
+        assertEquals((byte) (13 ^ 14), exec("def x = (byte)13; x ^= 14; return x;"));
+        // short
+        assertEquals((short) (13 ^ 14), exec("def x = (short)13; x ^= 14; return x;"));
+        // char
+        assertEquals((char) (13 ^ 14), exec("def x = (char)13; x ^= 14; return x;"));
+        // int
+        assertEquals(13 ^ 14, exec("def x = 13; x ^= 14; return x;"));
+        // long
+        assertEquals((long) (13 ^ 14), exec("def x = 13L; x ^= 14; return x;"));
+    }
+    
+    public void testDefBogusCompoundAssignment() {
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4F; int y = 1; x ^= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("def x = 4D; int y = 1; x ^= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; def y = (float)1; x ^= y");
+        });
+        expectScriptThrows(ClassCastException.class, () -> {
+            exec("int x = 4; def y = (double)1; x ^= y");
+        });
+    }
 }