Browse Source

Add a direct sub classes data structure to the Painless lookup (#76955)

This change has two main components.

The first is to have method/field resolution for compile-time and run-time use the same code path for 
now. This removes copying of member methods between super and sub classes and instead does a 
resolution through the class hierarchy. This allows us to correctly implement the next change.

The second is a data structure that allows for the lookup of direct sub classes for all allow listed 
classes/interfaces within Painless.
Jack Conradson 4 years ago
parent
commit
718b1635e2

+ 54 - 32
modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookup.java

@@ -9,6 +9,10 @@
 package org.elasticsearch.painless.lookup;
 
 import java.lang.invoke.MethodHandle;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
@@ -25,6 +29,7 @@ public final class PainlessLookup {
     private final Map<String, Class<?>> javaClassNamesToClasses;
     private final Map<String, Class<?>> canonicalClassNamesToClasses;
     private final Map<Class<?>, PainlessClass> classesToPainlessClasses;
+    private final Map<Class<?>, Set<Class<?>>> classesToDirectSubClasses;
 
     private final Map<String, PainlessMethod> painlessMethodKeysToImportedPainlessMethods;
     private final Map<String, PainlessClassBinding> painlessMethodKeysToPainlessClassBindings;
@@ -34,6 +39,7 @@ public final class PainlessLookup {
             Map<String, Class<?>> javaClassNamesToClasses,
             Map<String, Class<?>> canonicalClassNamesToClasses,
             Map<Class<?>, PainlessClass> classesToPainlessClasses,
+            Map<Class<?>, Set<Class<?>>> classesToDirectSubClasses,
             Map<String, PainlessMethod> painlessMethodKeysToImportedPainlessMethods,
             Map<String, PainlessClassBinding> painlessMethodKeysToPainlessClassBindings,
             Map<String, PainlessInstanceBinding> painlessMethodKeysToPainlessInstanceBindings) {
@@ -41,6 +47,7 @@ public final class PainlessLookup {
         Objects.requireNonNull(javaClassNamesToClasses);
         Objects.requireNonNull(canonicalClassNamesToClasses);
         Objects.requireNonNull(classesToPainlessClasses);
+        Objects.requireNonNull(classesToDirectSubClasses);
 
         Objects.requireNonNull(painlessMethodKeysToImportedPainlessMethods);
         Objects.requireNonNull(painlessMethodKeysToPainlessClassBindings);
@@ -49,6 +56,7 @@ public final class PainlessLookup {
         this.javaClassNamesToClasses = javaClassNamesToClasses;
         this.canonicalClassNamesToClasses = Map.copyOf(canonicalClassNamesToClasses);
         this.classesToPainlessClasses = Map.copyOf(classesToPainlessClasses);
+        this.classesToDirectSubClasses = Map.copyOf(classesToDirectSubClasses);
 
         this.painlessMethodKeysToImportedPainlessMethods = Map.copyOf(painlessMethodKeysToImportedPainlessMethods);
         this.painlessMethodKeysToPainlessClassBindings = Map.copyOf(painlessMethodKeysToPainlessClassBindings);
@@ -75,6 +83,10 @@ public final class PainlessLookup {
         return classesToPainlessClasses.keySet();
     }
 
+    public Set<Class<?>> getDirectSubClasses(Class<?> superClass) {
+        return classesToDirectSubClasses.get(superClass);
+    }
+
     public Set<String> getImportedPainlessMethodsKeys() {
         return painlessMethodKeysToImportedPainlessMethods.keySet();
     }
@@ -142,16 +154,12 @@ public final class PainlessLookup {
             targetClass = typeToBoxedType(targetClass);
         }
 
-        PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetClass);
         String painlessMethodKey = buildPainlessMethodKey(methodName, methodArity);
+        Function<PainlessClass, PainlessMethod> objectLookup = isStatic ?
+                targetPainlessClass -> targetPainlessClass.staticMethods.get(painlessMethodKey) :
+                targetPainlessClass -> targetPainlessClass.methods.get(painlessMethodKey);
 
-        if (targetPainlessClass == null) {
-            return null;
-        }
-
-        return isStatic ?
-                targetPainlessClass.staticMethods.get(painlessMethodKey) :
-                targetPainlessClass.methods.get(painlessMethodKey);
+        return lookupPainlessObject(targetClass, objectLookup);
     }
 
     public PainlessField lookupPainlessField(String targetCanonicalClassName, boolean isStatic, String fieldName) {
@@ -170,22 +178,12 @@ public final class PainlessLookup {
         Objects.requireNonNull(targetClass);
         Objects.requireNonNull(fieldName);
 
-        PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetClass);
         String painlessFieldKey = buildPainlessFieldKey(fieldName);
+        Function<PainlessClass, PainlessField> objectLookup = isStatic ?
+                targetPainlessClass -> targetPainlessClass.staticFields.get(painlessFieldKey) :
+                targetPainlessClass -> targetPainlessClass.fields.get(painlessFieldKey);
 
-        if (targetPainlessClass == null) {
-            return null;
-        }
-
-        PainlessField painlessField = isStatic ?
-                targetPainlessClass.staticFields.get(painlessFieldKey) :
-                targetPainlessClass.fields.get(painlessFieldKey);
-
-        if (painlessField == null) {
-            return null;
-        }
-
-        return painlessField;
+        return lookupPainlessObject(targetClass, objectLookup);
     }
 
     public PainlessMethod lookupImportedPainlessMethod(String methodName, int arity) {
@@ -230,7 +228,7 @@ public final class PainlessLookup {
         Function<PainlessClass, PainlessMethod> objectLookup =
                 targetPainlessClass -> targetPainlessClass.runtimeMethods.get(painlessMethodKey);
 
-        return lookupRuntimePainlessObject(originalTargetClass, objectLookup);
+        return lookupPainlessObject(originalTargetClass, objectLookup);
     }
 
     public MethodHandle lookupRuntimeGetterMethodHandle(Class<?> originalTargetClass, String getterName) {
@@ -239,7 +237,7 @@ public final class PainlessLookup {
 
         Function<PainlessClass, MethodHandle> objectLookup = targetPainlessClass -> targetPainlessClass.getterMethodHandles.get(getterName);
 
-        return lookupRuntimePainlessObject(originalTargetClass, objectLookup);
+        return lookupPainlessObject(originalTargetClass, objectLookup);
     }
 
     public MethodHandle lookupRuntimeSetterMethodHandle(Class<?> originalTargetClass, String setterName) {
@@ -248,10 +246,13 @@ public final class PainlessLookup {
 
         Function<PainlessClass, MethodHandle> objectLookup = targetPainlessClass -> targetPainlessClass.setterMethodHandles.get(setterName);
 
-        return lookupRuntimePainlessObject(originalTargetClass, objectLookup);
+        return lookupPainlessObject(originalTargetClass, objectLookup);
     }
 
-    private <T> T lookupRuntimePainlessObject(Class<?> originalTargetClass, Function<PainlessClass, T> objectLookup) {
+    private <T> T lookupPainlessObject(Class<?> originalTargetClass, Function<PainlessClass, T> objectLookup) {
+        Objects.requireNonNull(originalTargetClass);
+        Objects.requireNonNull(objectLookup);
+
         Class<?> currentTargetClass = originalTargetClass;
 
         while (currentTargetClass != null) {
@@ -268,17 +269,38 @@ public final class PainlessLookup {
             currentTargetClass = currentTargetClass.getSuperclass();
         }
 
+        if (originalTargetClass.isInterface()) {
+            PainlessClass targetPainlessClass = classesToPainlessClasses.get(Object.class);
+
+            if (targetPainlessClass != null) {
+                T painlessObject = objectLookup.apply(targetPainlessClass);
+
+                if (painlessObject != null) {
+                    return painlessObject;
+                }
+            }
+        }
+
         currentTargetClass = originalTargetClass;
+        Set<Class<?>> resolvedInterfaces = new HashSet<>();
 
         while (currentTargetClass != null) {
-            for (Class<?> targetInterface : currentTargetClass.getInterfaces()) {
-                PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetInterface);
+            List<Class<?>> targetInterfaces = new ArrayList<>(Arrays.asList(currentTargetClass.getInterfaces()));
+
+            while (targetInterfaces.isEmpty() == false) {
+                Class<?> targetInterface = targetInterfaces.remove(0);
+
+                if (resolvedInterfaces.add(targetInterface)) {
+                    PainlessClass targetPainlessClass = classesToPainlessClasses.get(targetInterface);
+
+                    if (targetPainlessClass != null) {
+                        T painlessObject = objectLookup.apply(targetPainlessClass);
 
-                if (targetPainlessClass != null) {
-                    T painlessObject = objectLookup.apply(targetPainlessClass);
+                        if (painlessObject != null) {
+                            return painlessObject;
+                        }
 
-                    if (painlessObject != null) {
-                        return painlessObject;
+                        targetInterfaces.addAll(Arrays.asList(targetInterface.getInterfaces()));
                     }
                 }
             }

+ 53 - 53
modules/lang-painless/src/main/java/org/elasticsearch/painless/lookup/PainlessLookupBuilder.java

@@ -42,11 +42,14 @@ import java.security.PrivilegedAction;
 import java.security.SecureClassLoader;
 import java.security.cert.Certificate;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.function.Supplier;
 import java.util.regex.Pattern;
 
@@ -189,6 +192,7 @@ public final class PainlessLookupBuilder {
     // of the values of javaClassNamesToClasses.
     private final Map<String, Class<?>> canonicalClassNamesToClasses;
     private final Map<Class<?>, PainlessClassBuilder> classesToPainlessClassBuilders;
+    private final Map<Class<?>, Set<Class<?>>> classesToDirectSubClasses;
 
     private final Map<String, PainlessMethod> painlessMethodKeysToImportedPainlessMethods;
     private final Map<String, PainlessClassBinding> painlessMethodKeysToPainlessClassBindings;
@@ -198,6 +202,7 @@ public final class PainlessLookupBuilder {
         javaClassNamesToClasses = new HashMap<>();
         canonicalClassNamesToClasses = new HashMap<>();
         classesToPainlessClassBuilders = new HashMap<>();
+        classesToDirectSubClasses = new HashMap<>();
 
         painlessMethodKeysToImportedPainlessMethods = new HashMap<>();
         painlessMethodKeysToPainlessClassBindings = new HashMap<>();
@@ -1255,7 +1260,7 @@ public final class PainlessLookupBuilder {
     }
 
     public PainlessLookup build() {
-        copyPainlessClassMembers();
+        buildPainlessClassHierarchy();
         setFunctionalInterfaceMethods();
         generateRuntimeMethods();
         cacheRuntimeHandles();
@@ -1286,71 +1291,66 @@ public final class PainlessLookupBuilder {
                 javaClassNamesToClasses,
                 canonicalClassNamesToClasses,
                 classesToPainlessClasses,
+                classesToDirectSubClasses,
                 painlessMethodKeysToImportedPainlessMethods,
                 painlessMethodKeysToPainlessClassBindings,
                 painlessMethodKeysToPainlessInstanceBindings);
     }
 
-    private void copyPainlessClassMembers() {
-        for (Class<?> parentClass : classesToPainlessClassBuilders.keySet()) {
-            copyPainlessInterfaceMembers(parentClass, parentClass);
-
-            Class<?> childClass = parentClass.getSuperclass();
-
-            while (childClass != null) {
-                if (classesToPainlessClassBuilders.containsKey(childClass)) {
-                    copyPainlessClassMembers(childClass, parentClass);
-                }
-
-                copyPainlessInterfaceMembers(childClass, parentClass);
-                childClass = childClass.getSuperclass();
-            }
-        }
-
-        for (Class<?> javaClass : classesToPainlessClassBuilders.keySet()) {
-            if (javaClass.isInterface()) {
-                copyPainlessClassMembers(Object.class, javaClass);
-            }
-        }
-    }
-
-    private void copyPainlessInterfaceMembers(Class<?> parentClass, Class<?> targetClass) {
-        for (Class<?> childClass : parentClass.getInterfaces()) {
-            if (classesToPainlessClassBuilders.containsKey(childClass)) {
-                copyPainlessClassMembers(childClass, targetClass);
-            }
-
-            copyPainlessInterfaceMembers(childClass, targetClass);
+    private void buildPainlessClassHierarchy() {
+        for (Class<?> targetClass : classesToPainlessClassBuilders.keySet()) {
+            classesToDirectSubClasses.put(targetClass, new HashSet<>());
         }
-    }
 
-    private void copyPainlessClassMembers(Class<?> originalClass, Class<?> targetClass) {
-        PainlessClassBuilder originalPainlessClassBuilder = classesToPainlessClassBuilders.get(originalClass);
-        PainlessClassBuilder targetPainlessClassBuilder = classesToPainlessClassBuilders.get(targetClass);
+        for (Class<?> subClass : classesToPainlessClassBuilders.keySet()) {
+            List<Class<?>> superInterfaces = new ArrayList<>(Arrays.asList(subClass.getInterfaces()));
 
-        Objects.requireNonNull(originalPainlessClassBuilder);
-        Objects.requireNonNull(targetPainlessClassBuilder);
+            // we check for Object.class as part of the allow listed classes because
+            // it is possible for the compiler to work without Object
+            if (subClass.isInterface() && superInterfaces.isEmpty() && classesToPainlessClassBuilders.containsKey(Object.class)) {
+                classesToDirectSubClasses.get(Object.class).add(subClass);
+            } else {
+                Class<?> superClass = subClass.getSuperclass();
+
+                // this finds the nearest super class for a given sub class
+                // because the allow list may have gaps between classes
+                // example:
+                // class A {}        // allowed
+                // class B extends A // not allowed
+                // class C extends B // allowed
+                // in this case C is considered a direct sub class of A
+                while (superClass != null) {
+                    if (classesToPainlessClassBuilders.containsKey(superClass)) {
+                        break;
+                    } else {
+                        // this ensures all interfaces from a sub class that
+                        // is not allow listed are checked if they are
+                        // considered a direct super class of the sub class
+                        // because these interfaces may still be allow listed
+                        // even if their sub class is not
+                        superInterfaces.addAll(Arrays.asList(superClass.getInterfaces()));
+                    }
 
-        for (Map.Entry<String, PainlessMethod> painlessMethodEntry : originalPainlessClassBuilder.methods.entrySet()) {
-            String painlessMethodKey = painlessMethodEntry.getKey();
-            PainlessMethod newPainlessMethod = painlessMethodEntry.getValue();
-            PainlessMethod existingPainlessMethod = targetPainlessClassBuilder.methods.get(painlessMethodKey);
+                    superClass = superClass.getSuperclass();
+                }
 
-            if (existingPainlessMethod == null || existingPainlessMethod.targetClass != newPainlessMethod.targetClass &&
-                    existingPainlessMethod.targetClass.isAssignableFrom(newPainlessMethod.targetClass)) {
-                targetPainlessClassBuilder.methods.put(painlessMethodKey.intern(), newPainlessMethod);
+                if (superClass != null) {
+                    classesToDirectSubClasses.get(superClass).add(subClass);
+                }
             }
-        }
 
-        for (Map.Entry<String, PainlessField> painlessFieldEntry : originalPainlessClassBuilder.fields.entrySet()) {
-            String painlessFieldKey = painlessFieldEntry.getKey();
-            PainlessField newPainlessField = painlessFieldEntry.getValue();
-            PainlessField existingPainlessField = targetPainlessClassBuilder.fields.get(painlessFieldKey);
+            Set<Class<?>> resolvedInterfaces = new HashSet<>();
+
+            while (superInterfaces.isEmpty() == false) {
+                Class<?> superInterface = superInterfaces.remove(0);
 
-            if (existingPainlessField == null ||
-                    existingPainlessField.javaField.getDeclaringClass() != newPainlessField.javaField.getDeclaringClass() &&
-                    existingPainlessField.javaField.getDeclaringClass().isAssignableFrom(newPainlessField.javaField.getDeclaringClass())) {
-                targetPainlessClassBuilder.fields.put(painlessFieldKey.intern(), newPainlessField);
+                if (resolvedInterfaces.add(superInterface)) {
+                    if (classesToPainlessClassBuilders.containsKey(superInterface)) {
+                        classesToDirectSubClasses.get(superInterface).add(subClass);
+                    } else {
+                        superInterfaces.addAll(Arrays.asList(superInterface.getInterfaces()));
+                    }
+                }
             }
         }
     }

+ 116 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/LookupTests.java

@@ -0,0 +1,116 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.painless;
+
+import org.elasticsearch.painless.lookup.PainlessLookup;
+import org.elasticsearch.painless.lookup.PainlessLookupBuilder;
+import org.elasticsearch.painless.spi.WhitelistLoader;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.Before;
+
+import java.util.Collections;
+import java.util.Set;
+
+public class LookupTests extends ESTestCase {
+
+    protected PainlessLookup painlessLookup;
+
+    @Before
+    public void setup() {
+        painlessLookup = PainlessLookupBuilder.buildFromWhitelists(Collections.singletonList(
+                WhitelistLoader.loadFromResourceFiles(PainlessPlugin.class, "org.elasticsearch.painless.lookup")
+        ));
+    }
+
+    public static class A { }           // in whitelist
+    public static class B extends A { } // not in whitelist
+    public static class C extends B { } // in whitelist
+    public static class D extends B { } // in whitelist
+
+    public interface Z { }              // in whitelist
+    public interface Y { }              // not in whitelist
+    public interface X extends Y, Z { } // not in whitelist
+    public interface V extends Y, Z { } // in whitelist
+    public interface U extends X { }    // in whitelist
+    public interface T extends V { }    // in whitelist
+    public interface S extends U, X { } // in whitelist
+
+    public static class AA implements X { }            // in whitelist
+    public static class AB extends AA implements S { } // not in whitelist
+    public static class AC extends AB implements V { } // in whitelist
+    public static class AD implements X, S, T { }      // in whitelist
+
+    public void testDirectSubClasses() {
+        Set<Class<?>> directSubClasses = painlessLookup.getDirectSubClasses(Object.class);
+        assertEquals(4, directSubClasses.size());
+        assertTrue(directSubClasses.contains(A.class));
+        assertTrue(directSubClasses.contains(Z.class));
+        assertTrue(directSubClasses.contains(AA.class));
+        assertTrue(directSubClasses.contains(AD.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(A.class);
+        assertEquals(2, directSubClasses.size());
+        assertTrue(directSubClasses.contains(D.class));
+        assertTrue(directSubClasses.contains(C.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(B.class);
+        assertNull(directSubClasses);
+
+        directSubClasses = painlessLookup.getDirectSubClasses(C.class);
+        assertTrue(directSubClasses.isEmpty());
+
+        directSubClasses = painlessLookup.getDirectSubClasses(D.class);
+        assertTrue(directSubClasses.isEmpty());
+
+        directSubClasses = painlessLookup.getDirectSubClasses(Z.class);
+        assertEquals(5, directSubClasses.size());
+        assertTrue(directSubClasses.contains(V.class));
+        assertTrue(directSubClasses.contains(U.class));
+        assertTrue(directSubClasses.contains(S.class));
+        assertTrue(directSubClasses.contains(AA.class));
+        assertTrue(directSubClasses.contains(AD.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(Y.class);
+        assertNull(directSubClasses);
+
+        directSubClasses = painlessLookup.getDirectSubClasses(X.class);
+        assertNull(directSubClasses);
+
+        directSubClasses = painlessLookup.getDirectSubClasses(V.class);
+        assertEquals(2, directSubClasses.size());
+        assertTrue(directSubClasses.contains(T.class));
+        assertTrue(directSubClasses.contains(AC.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(U.class);
+        assertEquals(1, directSubClasses.size());
+        assertTrue(directSubClasses.contains(S.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(T.class);
+        assertEquals(1, directSubClasses.size());
+        assertTrue(directSubClasses.contains(AD.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(S.class);
+        assertEquals(2, directSubClasses.size());
+        assertTrue(directSubClasses.contains(AC.class));
+        assertTrue(directSubClasses.contains(AD.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(AA.class);
+        assertEquals(1, directSubClasses.size());
+        assertTrue(directSubClasses.contains(AC.class));
+
+        directSubClasses = painlessLookup.getDirectSubClasses(AB.class);
+        assertNull(directSubClasses);
+
+        directSubClasses = painlessLookup.getDirectSubClasses(AC.class);
+        assertTrue(directSubClasses.isEmpty());
+
+        directSubClasses = painlessLookup.getDirectSubClasses(AD.class);
+        assertTrue(directSubClasses.isEmpty());
+    }
+}

+ 35 - 0
modules/lang-painless/src/test/resources/org/elasticsearch/painless/org.elasticsearch.painless.lookup

@@ -0,0 +1,35 @@
+class java.lang.Object {
+}
+
+class org.elasticsearch.painless.LookupTests$A {
+}
+
+class org.elasticsearch.painless.LookupTests$C {
+}
+
+class org.elasticsearch.painless.LookupTests$D {
+}
+
+class org.elasticsearch.painless.LookupTests$Z {
+}
+
+class org.elasticsearch.painless.LookupTests$V {
+}
+
+class org.elasticsearch.painless.LookupTests$U {
+}
+
+class org.elasticsearch.painless.LookupTests$T {
+}
+
+class org.elasticsearch.painless.LookupTests$S {
+}
+
+class org.elasticsearch.painless.LookupTests$AA {
+}
+
+class org.elasticsearch.painless.LookupTests$AC {
+}
+
+class org.elasticsearch.painless.LookupTests$AD {
+}