Răsfoiți Sursa

Use an explicit null check for null receivers, rather than an NPE (#91347)

Use an explicit null check for check the receiver of nullable painless ops & then fallback, rather than an NPE
Simon Cooper 2 ani în urmă
părinte
comite
ebdba71379

+ 5 - 0
docs/changelog/91347.yaml

@@ -0,0 +1,5 @@
+pr: 91347
+summary: "Use an explicit null check for null receivers in painless, rather than an NPE"
+area: Infra/Scripting
+type: enhancement
+issues: [91236]

+ 31 - 12
modules/lang-painless/src/main/java/org/elasticsearch/painless/DefBootstrap.java

@@ -19,6 +19,7 @@ import java.lang.invoke.MethodType;
 import java.lang.invoke.MutableCallSite;
 import java.lang.invoke.WrongMethodTypeException;
 import java.util.Map;
+import java.util.Objects;
 
 /**
  * Painless invokedynamic bootstrap for the call site.
@@ -366,40 +367,42 @@ public final class DefBootstrap {
                 throw exc;
             }
 
-            final MethodHandle test;
+            MethodHandle test;
+            MethodHandle nullCheck = null;
             if (flavor == BINARY_OPERATOR || flavor == SHIFT_OPERATOR) {
                 // some binary operators support nulls, we handle them separate
                 Class<?> clazz0 = args[0] == null ? null : args[0].getClass();
                 Class<?> clazz1 = args[1] == null ? null : args[1].getClass();
                 if (type.parameterType(1) != Object.class) {
                     // case 1: only the receiver is unknown, just check that
+                    MethodType testType = MethodType.methodType(boolean.class, type.parameterType(0));
                     MethodHandle unaryTest = CHECK_LHS.bindTo(clazz0);
-                    test = unaryTest.asType(unaryTest.type().changeParameterType(0, type.parameterType(0)));
+                    test = unaryTest.asType(testType);
+                    nullCheck = NON_NULL.asType(testType);
                 } else if (type.parameterType(0) != Object.class) {
                     // case 2: only the argument is unknown, just check that
+                    MethodType testType = MethodType.methodType(boolean.class, type);
                     MethodHandle unaryTest = CHECK_RHS.bindTo(clazz0).bindTo(clazz1);
-                    test = unaryTest.asType(
-                        unaryTest.type().changeParameterType(0, type.parameterType(0)).changeParameterType(1, type.parameterType(1))
-                    );
+                    test = unaryTest.asType(testType);
+                    nullCheck = MethodHandles.dropArguments(NON_NULL, 0, clazz0).asType(testType);
                 } else {
                     // case 3: check both receiver and argument
+                    MethodType testType = MethodType.methodType(boolean.class, type);
                     MethodHandle binaryTest = CHECK_BOTH.bindTo(clazz0).bindTo(clazz1);
-                    test = binaryTest.asType(
-                        binaryTest.type().changeParameterType(0, type.parameterType(0)).changeParameterType(1, type.parameterType(1))
-                    );
+                    test = binaryTest.asType(testType);
+                    nullCheck = BOTH_NON_NULL.asType(testType);
                 }
             } else {
                 // unary operator
                 MethodHandle receiverTest = CHECK_LHS.bindTo(args[0].getClass());
-                test = receiverTest.asType(receiverTest.type().changeParameterType(0, type.parameterType(0)));
+                test = receiverTest.asType(MethodType.methodType(boolean.class, type.parameterType(0)));
             }
 
             MethodHandle guard = MethodHandles.guardWithTest(test, target, getTarget());
             // very special cases, where even the receiver can be null (see JLS rules for string concat)
-            // we wrap + with an NPE catcher, and use our generic method in that case.
+            // we wrap op with a null check, and use our generic method in that case.
             if (flavor == BINARY_OPERATOR && (flags & OPERATOR_ALLOWS_NULL) != 0) {
-                MethodHandle handler = MethodHandles.dropArguments(lookupGeneric().asType(type()), 0, NullPointerException.class);
-                guard = MethodHandles.catchException(guard, NullPointerException.class, handler);
+                guard = MethodHandles.guardWithTest(nullCheck, guard, lookupGeneric().asType(type()));
             }
 
             initialized = true;
@@ -432,10 +435,20 @@ public final class DefBootstrap {
             return leftObject.getClass() == left && rightObject.getClass() == right;
         }
 
+        /**
+         * Null guard method for caching - ensures both left and right are non-null,
+         * so checkBoth can be called successfully
+         */
+        static boolean bothNonNull(Object leftObject, Object rightObject) {
+            return leftObject != null && rightObject != null;
+        }
+
         private static final MethodHandle CHECK_LHS;
         private static final MethodHandle CHECK_RHS;
         private static final MethodHandle CHECK_BOTH;
         private static final MethodHandle FALLBACK;
+        private static final MethodHandle NON_NULL;
+        private static final MethodHandle BOTH_NON_NULL;
         static {
             final MethodHandles.Lookup methodHandlesLookup = MethodHandles.lookup();
             try {
@@ -459,6 +472,12 @@ public final class DefBootstrap {
                     "fallback",
                     MethodType.methodType(Object.class, Object[].class)
                 );
+                NON_NULL = methodHandlesLookup.findStatic(Objects.class, "nonNull", MethodType.methodType(boolean.class, Object.class));
+                BOTH_NON_NULL = methodHandlesLookup.findStatic(
+                    methodHandlesLookup.lookupClass(),
+                    "bothNonNull",
+                    MethodType.methodType(boolean.class, Object.class, Object.class)
+                );
             } catch (ReflectiveOperationException e) {
                 throw new AssertionError(e);
             }

+ 17 - 1
modules/lang-painless/src/test/java/org/elasticsearch/painless/EqualsTests.java

@@ -10,6 +10,8 @@ package org.elasticsearch.painless;
 
 import org.elasticsearch.test.ESTestCase;
 
+import java.util.Map;
+
 import static java.util.Collections.singletonMap;
 
 public class EqualsTests extends ScriptTestCase {
@@ -200,6 +202,20 @@ public class EqualsTests extends ScriptTestCase {
         assertBytecodeExists("def x = \"a\"; return \"a\" == x", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
         assertBytecodeExists("def x = \"a\"; return \"a\" != x", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
         assertBytecodeExists("def x = \"a\"; return x == \"a\"", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
-        assertBytecodeExists("def x = \"a\"; return x != \"a\"", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
+    }
+
+    public void testEqualsNullCheck() {
+        // get the same callsite working once, then with a null
+        // need to specify call site depth as 0 to force MIC to execute
+        assertEquals(false, exec("""
+            def list = [2, 2, 3, 3, 4, null];
+            boolean b;
+            for (int i=0; i<list.length; i+=2) {
+                b = list[i] == list[i+1];
+                b = list[i+1] == 10;
+                b = 10 == list[i+1];
+            }
+            return b;
+            """, Map.of(), Map.of(CompilerSettings.INITIAL_CALL_SITE_DEPTH, "0"), false));
     }
 }

+ 1 - 3
modules/lang-painless/src/test/java/org/elasticsearch/painless/ScriptTestCase.java

@@ -71,9 +71,7 @@ public abstract class ScriptTestCase extends ESTestCase {
 
     /** Compiles and returns the result of {@code script} with access to {@code vars} */
     public Object exec(String script, Map<String, Object> vars, boolean picky) {
-        Map<String, String> compilerSettings = new HashMap<>();
-        compilerSettings.put(CompilerSettings.INITIAL_CALL_SITE_DEPTH, random().nextBoolean() ? "0" : "10");
-        return exec(script, vars, compilerSettings, picky);
+        return exec(script, vars, Map.of(CompilerSettings.INITIAL_CALL_SITE_DEPTH, random().nextBoolean() ? "0" : "10"), picky);
     }
 
     /** Compiles and returns the result of {@code script} with access to {@code vars} and compile-time parameters */