Procházet zdrojové kódy

Simpler regex constants in painless (#68486)

Replaces the double `Pattern.compile` invocations in painless scripts
with the fancy constant injection we added in #68088. This caused one of
the tests to fail. It turns out that we weren't fully iterating the IR
tree during the constant folding phases. I started experimenting and
added a ton of tests that failed. Then I fixed them by changing the IR
tree walking code.
Nik Everett před 4 roky
rodič
revize
e686e18819
14 změnil soubory, kde provedl 80 přidání a 95 odebrání
  1. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreBraceDefNode.java
  2. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreBraceNode.java
  3. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreDotDefNode.java
  4. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreDotNode.java
  5. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreDotShortcutNode.java
  6. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreFieldMemberNode.java
  7. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreListShortcutNode.java
  8. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreMapShortcutNode.java
  9. 1 1
      modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreVariableNode.java
  10. 4 0
      modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java
  11. 12 11
      modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java
  12. 5 73
      modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java
  13. 24 0
      modules/lang-painless/src/test/java/org/elasticsearch/painless/ConstantFoldingTests.java
  14. 26 2
      modules/lang-painless/src/test/java/org/elasticsearch/painless/RegexTests.java

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreBraceDefNode.java

@@ -22,7 +22,7 @@ public class StoreBraceDefNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreBraceNode.java

@@ -22,7 +22,7 @@ public class StoreBraceNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreDotDefNode.java

@@ -22,7 +22,7 @@ public class StoreDotDefNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreDotNode.java

@@ -22,7 +22,7 @@ public class StoreDotNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreDotShortcutNode.java

@@ -22,7 +22,7 @@ public class StoreDotShortcutNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreFieldMemberNode.java

@@ -27,7 +27,7 @@ public class StoreFieldMemberNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreListShortcutNode.java

@@ -22,7 +22,7 @@ public class StoreListShortcutNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreMapShortcutNode.java

@@ -22,7 +22,7 @@ public class StoreMapShortcutNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 1 - 1
modules/lang-painless/src/main/java/org/elasticsearch/painless/ir/StoreVariableNode.java

@@ -22,7 +22,7 @@ public class StoreVariableNode extends UnaryNode {
 
     @Override
     public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
-        // do nothing; terminal node
+        getChildNode().visit(irTreeVisitor, scope);
     }
 
     /* ---- end visitor ---- */

+ 4 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultIRTreeToASMBytesPhase.java

@@ -1216,6 +1216,10 @@ public class DefaultIRTreeToASMBytesPhase implements IRTreeVisitor<WriteScope> {
              */
             String fieldName = irConstantNode.getDecorationValue(IRDConstantFieldName.class);
             Type asmFieldType = MethodWriter.getType(irConstantNode.getDecorationValue(IRDExpressionType.class));
+            if (asmFieldType == null) {
+                throw irConstantNode.getLocation()
+                    .createError(new IllegalStateException("Didn't attach constant to [" + irConstantNode + "]"));
+            }
             methodWriter.getStatic(CLASS_TYPE, fieldName, asmFieldType);
         }
     }

+ 12 - 11
modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultSemanticAnalysisPhase.java

@@ -2051,43 +2051,44 @@ public class DefaultSemanticAnalysisPhase extends UserTreeBaseVisitor<SemanticSc
 
         Location location = userRegexNode.getLocation();
 
-        int constant = 0;
+        int regexFlags = 0;
 
         for (int i = 0; i < flags.length(); ++i) {
             char flag = flags.charAt(i);
 
             switch (flag) {
                 case 'c':
-                    constant |= Pattern.CANON_EQ;
+                    regexFlags |= Pattern.CANON_EQ;
                     break;
                 case 'i':
-                    constant |= Pattern.CASE_INSENSITIVE;
+                    regexFlags |= Pattern.CASE_INSENSITIVE;
                     break;
                 case 'l':
-                    constant |= Pattern.LITERAL;
+                    regexFlags |= Pattern.LITERAL;
                     break;
                 case 'm':
-                    constant |= Pattern.MULTILINE;
+                    regexFlags |= Pattern.MULTILINE;
                     break;
                 case 's':
-                    constant |= Pattern.DOTALL;
+                    regexFlags |= Pattern.DOTALL;
                     break;
                 case 'U':
-                    constant |= Pattern.UNICODE_CHARACTER_CLASS;
+                    regexFlags |= Pattern.UNICODE_CHARACTER_CLASS;
                     break;
                 case 'u':
-                    constant |= Pattern.UNICODE_CASE;
+                    regexFlags |= Pattern.UNICODE_CASE;
                     break;
                 case 'x':
-                    constant |= Pattern.COMMENTS;
+                    regexFlags |= Pattern.COMMENTS;
                     break;
                 default:
                     throw new IllegalArgumentException("invalid regular expression: unknown flag [" + flag + "]");
             }
         }
 
+        Pattern compiled;
         try {
-            Pattern.compile(pattern, constant);
+            compiled = Pattern.compile(pattern, regexFlags);
         } catch (PatternSyntaxException pse) {
             throw new Location(location.getSourceName(), location.getOffset() + 1 + pse.getIndex()).createError(
                     new IllegalArgumentException("invalid regular expression: " +
@@ -2095,7 +2096,7 @@ public class DefaultSemanticAnalysisPhase extends UserTreeBaseVisitor<SemanticSc
         }
 
         semanticScope.putDecoration(userRegexNode, new ValueType(Pattern.class));
-        semanticScope.putDecoration(userRegexNode, new StandardConstant(constant));
+        semanticScope.putDecoration(userRegexNode, new StandardConstant(compiled));
     }
 
     /**

+ 5 - 73
modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultUserTreeToIRTreePhase.java

@@ -75,7 +75,6 @@ import org.elasticsearch.painless.ir.StoreBraceNode;
 import org.elasticsearch.painless.ir.StoreDotDefNode;
 import org.elasticsearch.painless.ir.StoreDotNode;
 import org.elasticsearch.painless.ir.StoreDotShortcutNode;
-import org.elasticsearch.painless.ir.StoreFieldMemberNode;
 import org.elasticsearch.painless.ir.StoreListShortcutNode;
 import org.elasticsearch.painless.ir.StoreMapShortcutNode;
 import org.elasticsearch.painless.ir.StoreVariableNode;
@@ -206,8 +205,8 @@ import org.elasticsearch.painless.symbol.IRDecorations.IRDComparisonType;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDConstant;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDConstructor;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDDeclarationType;
-import org.elasticsearch.painless.symbol.IRDecorations.IRDDepth;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDDefReferenceEncoding;
+import org.elasticsearch.painless.symbol.IRDecorations.IRDDepth;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDExceptionType;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDExpressionType;
 import org.elasticsearch.painless.symbol.IRDecorations.IRDField;
@@ -1321,77 +1320,10 @@ public class DefaultUserTreeToIRTreePhase implements UserTreeVisitor<ScriptScope
 
     @Override
     public void visitRegex(ERegex userRegexNode, ScriptScope scriptScope) {
-        String memberFieldName = scriptScope.getNextSyntheticName("regex");
-
-        FieldNode irFieldNode = new FieldNode(userRegexNode.getLocation());
-        irFieldNode.attachDecoration(new IRDModifiers(Modifier.FINAL | Modifier.STATIC | Modifier.PRIVATE));
-        irFieldNode.attachDecoration(new IRDFieldType(Pattern.class));
-        irFieldNode.attachDecoration(new IRDName(memberFieldName));
-
-        irClassNode.addFieldNode(irFieldNode);
-
-        try {
-            StatementExpressionNode irStatementExpressionNode = new StatementExpressionNode(userRegexNode.getLocation());
-
-            BlockNode blockNode = irClassNode.getClinitBlockNode();
-            blockNode.addStatementNode(irStatementExpressionNode);
-
-            StoreFieldMemberNode irStoreFieldMemberNode = new StoreFieldMemberNode(userRegexNode.getLocation());
-            irStoreFieldMemberNode.attachDecoration(new IRDExpressionType(void.class));
-            irStoreFieldMemberNode.attachDecoration(new IRDStoreType(Pattern.class));
-            irStoreFieldMemberNode.attachDecoration(new IRDName(memberFieldName));
-            irStoreFieldMemberNode.attachCondition(IRCStatic.class);
-
-            irStatementExpressionNode.setExpressionNode(irStoreFieldMemberNode);
-
-            BinaryImplNode irBinaryImplNode = new BinaryImplNode(userRegexNode.getLocation());
-            irBinaryImplNode.attachDecoration(new IRDExpressionType(Pattern.class));
-
-            irStoreFieldMemberNode.setChildNode(irBinaryImplNode);
-
-            StaticNode irStaticNode = new StaticNode(userRegexNode.getLocation());
-            irStaticNode.attachDecoration(new IRDExpressionType(Pattern.class));
-
-            irBinaryImplNode.setLeftNode(irStaticNode);
-
-            InvokeCallNode invokeCallNode = new InvokeCallNode(userRegexNode.getLocation());
-            invokeCallNode.attachDecoration(new IRDExpressionType(Pattern.class));
-            invokeCallNode.setBox(Pattern.class);
-            invokeCallNode.setMethod(new PainlessMethod(
-                            Pattern.class.getMethod("compile", String.class, int.class),
-                            Pattern.class,
-                            Pattern.class,
-                            Arrays.asList(String.class, int.class),
-                            null,
-                            null,
-                            null
-                    )
-            );
-
-            irBinaryImplNode.setRightNode(invokeCallNode);
-
-            ConstantNode irConstantNode = new ConstantNode(userRegexNode.getLocation());
-            irConstantNode.attachDecoration(new IRDExpressionType(String.class));
-            irConstantNode.attachDecoration(new IRDConstant(userRegexNode.getPattern()));
-
-            invokeCallNode.addArgumentNode(irConstantNode);
-
-            irConstantNode = new ConstantNode(userRegexNode.getLocation());
-            irConstantNode.attachDecoration(new IRDExpressionType(int.class));
-            irConstantNode.attachDecoration(
-                    new IRDConstant(scriptScope.getDecoration(userRegexNode, StandardConstant.class).getStandardConstant()));
-
-            invokeCallNode.addArgumentNode(irConstantNode);
-        } catch (Exception exception) {
-            throw userRegexNode.createError(new IllegalStateException("illegal tree structure"));
-        }
-
-        LoadFieldMemberNode irLoadFieldMemberNode = new LoadFieldMemberNode(userRegexNode.getLocation());
-        irLoadFieldMemberNode.attachDecoration(new IRDExpressionType(Pattern.class));
-        irLoadFieldMemberNode.attachDecoration(new IRDName(memberFieldName));
-        irLoadFieldMemberNode.attachCondition(IRCStatic.class);
-
-        scriptScope.putDecoration(userRegexNode, new IRNodeDecoration(irLoadFieldMemberNode));
+        ConstantNode constant = new ConstantNode(userRegexNode.getLocation());
+        constant.attachDecoration(new IRDExpressionType(Pattern.class));
+        constant.attachDecoration(new IRDConstant(scriptScope.getDecoration(userRegexNode, StandardConstant.class).getStandardConstant()));
+        scriptScope.putDecoration(userRegexNode, new IRNodeDecoration(constant));
     }
 
     @Override

+ 24 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/ConstantFoldingTests.java

@@ -106,4 +106,28 @@ public class ConstantFoldingTests extends ScriptTestCase {
         assertBytecodeExists("2+'2D'", "LDC \"22D\"");
         assertBytecodeExists("4L<5F", "ICONST_1");
     }
+
+    public void testStoreInMap()  {
+        assertBytecodeExists("Map m = [:]; m.a = 1 + 1; m.a", "ICONST_2");
+    }
+
+    public void testStoreInMapDef()  {
+        assertBytecodeExists("def m = [:]; m.a = 1 + 1; m.a", "ICONST_2");
+    }
+
+    public void testStoreInList()  {
+        assertBytecodeExists("List l = [null]; l.0 = 1 + 1; l.0", "ICONST_2");
+    }
+
+    public void testStoreInListDef()  {
+        assertBytecodeExists("def l = [null]; l.0 = 1 + 1; l.0", "ICONST_2");
+    }
+
+    public void testStoreInArray()  {
+        assertBytecodeExists("int[] a = new int[1]; a[0] = 1 + 1; a[0]", "ICONST_2");
+    }
+
+    public void testStoreInArrayDef()  {
+        assertBytecodeExists("def a = new int[1]; a[0] = 1 + 1; a[0]", "ICONST_2");
+    }
 }

+ 26 - 2
modules/lang-painless/src/test/java/org/elasticsearch/painless/RegexTests.java

@@ -59,9 +59,9 @@ public class RegexTests extends ScriptTestCase {
 
     public void testInTernaryCondition()  {
         assertEquals(true, exec("return /foo/.matcher('foo').matches() ? true : false"));
-        assertEquals(1, exec("def i = 0; i += /foo/.matcher('foo').matches() ? 1 : 1; return i"));
+        assertEquals(1, exec("def i = 0; i += /foo/.matcher('foo').matches() ? 1 : 0; return i"));
         assertEquals(true, exec("return 'foo' ==~ /foo/ ? true : false"));
-        assertEquals(1, exec("def i = 0; i += 'foo' ==~ /foo/ ? 1 : 1; return i"));
+        assertEquals(1, exec("def i = 0; i += 'foo' ==~ /foo/ ? 1 : 0; return i"));
     }
 
     public void testInTernaryTrueArm()  {
@@ -232,6 +232,30 @@ public class RegexTests extends ScriptTestCase {
                 exec("'the quick brown fox'.replaceFirst(/[aeiou]/, m -> '$' + m.group().toUpperCase(Locale.ROOT))"));
     }
 
+    public void testStoreInMap()  {
+        assertEquals(true, exec("Map m = [:]; m.a = /foo/; m.a.matcher('foo').matches()"));
+    }
+
+    public void testStoreInMapDef()  {
+        assertEquals(true, exec("def m = [:]; m.a = /foo/; m.a.matcher('foo').matches()"));
+    }
+
+    public void testStoreInList()  {
+        assertEquals(true, exec("List l = [null]; l.0 = /foo/; l.0.matcher('foo').matches()"));
+    }
+
+    public void testStoreInListDef()  {
+        assertEquals(true, exec("def l = [null]; l.0 = /foo/; l.0.matcher('foo').matches()"));
+    }
+
+    public void testStoreInArray()  {
+        assertEquals(true, exec("Pattern[] a = new Pattern[1]; a[0] = /foo/; a[0].matcher('foo').matches()"));
+    }
+
+    public void testStoreInArrayDef()  {
+        assertEquals(true, exec("def a = new Pattern[1]; a[0] = /foo/; a[0].matcher('foo').matches()"));
+    }
+
     public void testCantUsePatternCompile() {
         IllegalArgumentException e = expectScriptThrows(IllegalArgumentException.class, () -> {
             exec("Pattern.compile('aa')");