Browse Source

Introduce a phase to use String.equals on constant strings, rather than def equality (#91362)

Add a painless phase to use String.equals on equality checks with constant strings, rather than dynamic def equality
Simon Cooper 2 years ago
parent
commit
d3b35795f2

+ 6 - 0
docs/changelog/91362.yaml

@@ -0,0 +1,6 @@
+pr: 91362
+summary: "Introduce a phase to use String.equals on constant strings, rather than\
+  \ def equality"
+area: Infra/Core
+type: enhancement
+issues: [91235]

+ 4 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/Compiler.java

@@ -14,6 +14,7 @@ import org.elasticsearch.painless.ir.ClassNode;
 import org.elasticsearch.painless.lookup.PainlessLookup;
 import org.elasticsearch.painless.node.SClass;
 import org.elasticsearch.painless.phase.DefaultConstantFoldingOptimizationPhase;
+import org.elasticsearch.painless.phase.DefaultEqualityMethodOptimizationPhase;
 import org.elasticsearch.painless.phase.DefaultIRTreeToASMBytesPhase;
 import org.elasticsearch.painless.phase.DefaultStaticConstantExtractionPhase;
 import org.elasticsearch.painless.phase.DefaultStringConcatenationOptimizationPhase;
@@ -217,6 +218,7 @@ final class Compiler {
         ClassNode classNode = (ClassNode) scriptScope.getDecoration(root, IRNodeDecoration.class).irNode();
         new DefaultStringConcatenationOptimizationPhase().visitClass(classNode, null);
         new DefaultConstantFoldingOptimizationPhase().visitClass(classNode, null);
+        new DefaultEqualityMethodOptimizationPhase(scriptScope).visitClass(classNode, null);
         new DefaultStaticConstantExtractionPhase().visitClass(classNode, scriptScope);
         new DefaultIRTreeToASMBytesPhase().visitScript(classNode);
         byte[] bytes = classNode.getBytes();
@@ -252,6 +254,7 @@ final class Compiler {
         ClassNode classNode = (ClassNode) scriptScope.getDecoration(root, IRNodeDecoration.class).irNode();
         new DefaultStringConcatenationOptimizationPhase().visitClass(classNode, null);
         new DefaultConstantFoldingOptimizationPhase().visitClass(classNode, null);
+        new DefaultEqualityMethodOptimizationPhase(scriptScope).visitClass(classNode, null);
         new DefaultStaticConstantExtractionPhase().visitClass(classNode, scriptScope);
         classNode.setDebugStream(debugStream);
         new DefaultIRTreeToASMBytesPhase().visitScript(classNode);
@@ -290,6 +293,7 @@ final class Compiler {
         ClassNode classNode = (ClassNode) scriptScope.getDecoration(root, IRNodeDecoration.class).irNode();
         new DefaultStringConcatenationOptimizationPhase().visitClass(classNode, null);
         new DefaultConstantFoldingOptimizationPhase().visitClass(classNode, null);
+        new DefaultEqualityMethodOptimizationPhase(scriptScope).visitClass(classNode, null);
         new DefaultStaticConstantExtractionPhase().visitClass(classNode, scriptScope);
         classNode.setDebugStream(debugStream);
 

+ 14 - 291
modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultConstantFoldingOptimizationPhase.java

@@ -10,51 +10,16 @@ package org.elasticsearch.painless.phase;
 
 import org.elasticsearch.painless.AnalyzerCaster;
 import org.elasticsearch.painless.Operation;
-import org.elasticsearch.painless.ir.BinaryImplNode;
 import org.elasticsearch.painless.ir.BinaryMathNode;
 import org.elasticsearch.painless.ir.BooleanNode;
 import org.elasticsearch.painless.ir.CastNode;
 import org.elasticsearch.painless.ir.ComparisonNode;
-import org.elasticsearch.painless.ir.ConditionalNode;
 import org.elasticsearch.painless.ir.ConstantNode;
-import org.elasticsearch.painless.ir.DeclarationNode;
-import org.elasticsearch.painless.ir.DoWhileLoopNode;
-import org.elasticsearch.painless.ir.DupNode;
-import org.elasticsearch.painless.ir.ElvisNode;
 import org.elasticsearch.painless.ir.ExpressionNode;
-import org.elasticsearch.painless.ir.FlipArrayIndexNode;
-import org.elasticsearch.painless.ir.FlipCollectionIndexNode;
-import org.elasticsearch.painless.ir.FlipDefIndexNode;
-import org.elasticsearch.painless.ir.ForEachSubArrayNode;
-import org.elasticsearch.painless.ir.ForEachSubIterableNode;
-import org.elasticsearch.painless.ir.ForLoopNode;
-import org.elasticsearch.painless.ir.IfElseNode;
-import org.elasticsearch.painless.ir.IfNode;
-import org.elasticsearch.painless.ir.InstanceofNode;
-import org.elasticsearch.painless.ir.InvokeCallDefNode;
 import org.elasticsearch.painless.ir.InvokeCallMemberNode;
-import org.elasticsearch.painless.ir.InvokeCallNode;
-import org.elasticsearch.painless.ir.ListInitializationNode;
-import org.elasticsearch.painless.ir.MapInitializationNode;
-import org.elasticsearch.painless.ir.NewArrayNode;
-import org.elasticsearch.painless.ir.NewObjectNode;
 import org.elasticsearch.painless.ir.NullNode;
-import org.elasticsearch.painless.ir.NullSafeSubNode;
-import org.elasticsearch.painless.ir.ReturnNode;
-import org.elasticsearch.painless.ir.StatementExpressionNode;
-import org.elasticsearch.painless.ir.StoreBraceDefNode;
-import org.elasticsearch.painless.ir.StoreBraceNode;
-import org.elasticsearch.painless.ir.StoreDotDefNode;
-import org.elasticsearch.painless.ir.StoreDotNode;
-import org.elasticsearch.painless.ir.StoreDotShortcutNode;
-import org.elasticsearch.painless.ir.StoreFieldMemberNode;
-import org.elasticsearch.painless.ir.StoreListShortcutNode;
-import org.elasticsearch.painless.ir.StoreMapShortcutNode;
-import org.elasticsearch.painless.ir.StoreVariableNode;
 import org.elasticsearch.painless.ir.StringConcatenationNode;
-import org.elasticsearch.painless.ir.ThrowNode;
 import org.elasticsearch.painless.ir.UnaryMathNode;
-import org.elasticsearch.painless.ir.WhileLoopNode;
 import org.elasticsearch.painless.lookup.PainlessInstanceBinding;
 import org.elasticsearch.painless.lookup.PainlessLookupUtility;
 import org.elasticsearch.painless.lookup.PainlessMethod;
@@ -78,108 +43,13 @@ import java.util.function.Consumer;
  * for a child node to introspect into its parent node, so to replace itself the parent node
  * must pass the child node's particular set method as method reference.
  */
-public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<Consumer<ExpressionNode>> {
-
-    @Override
-    public void visitIf(IfNode irIfNode, Consumer<ExpressionNode> scope) {
-        irIfNode.getConditionNode().visit(this, irIfNode::setConditionNode);
-        irIfNode.getBlockNode().visit(this, null);
-    }
-
-    @Override
-    public void visitIfElse(IfElseNode irIfElseNode, Consumer<ExpressionNode> scope) {
-        irIfElseNode.getConditionNode().visit(this, irIfElseNode::setConditionNode);
-        irIfElseNode.getBlockNode().visit(this, null);
-        irIfElseNode.getElseBlockNode().visit(this, null);
-    }
-
-    @Override
-    public void visitWhileLoop(WhileLoopNode irWhileLoopNode, Consumer<ExpressionNode> scope) {
-        if (irWhileLoopNode.getConditionNode() != null) {
-            irWhileLoopNode.getConditionNode().visit(this, irWhileLoopNode::setConditionNode);
-        }
-
-        if (irWhileLoopNode.getBlockNode() != null) {
-            irWhileLoopNode.getBlockNode().visit(this, null);
-        }
-    }
-
-    @Override
-    public void visitDoWhileLoop(DoWhileLoopNode irDoWhileLoopNode, Consumer<ExpressionNode> scope) {
-        irDoWhileLoopNode.getBlockNode().visit(this, null);
-
-        if (irDoWhileLoopNode.getConditionNode() != null) {
-            irDoWhileLoopNode.getConditionNode().visit(this, irDoWhileLoopNode::setConditionNode);
-        }
-    }
-
-    @Override
-    public void visitForLoop(ForLoopNode irForLoopNode, Consumer<ExpressionNode> scope) {
-        if (irForLoopNode.getInitializerNode() != null) {
-            irForLoopNode.getInitializerNode().visit(this, irForLoopNode::setInitializerNode);
-        }
-
-        if (irForLoopNode.getConditionNode() != null) {
-            irForLoopNode.getConditionNode().visit(this, irForLoopNode::setConditionNode);
-        }
-
-        if (irForLoopNode.getAfterthoughtNode() != null) {
-            irForLoopNode.getAfterthoughtNode().visit(this, irForLoopNode::setAfterthoughtNode);
-        }
-
-        if (irForLoopNode.getBlockNode() != null) {
-            irForLoopNode.getBlockNode().visit(this, null);
-        }
-    }
-
-    @Override
-    public void visitForEachSubArrayLoop(ForEachSubArrayNode irForEachSubArrayNode, Consumer<ExpressionNode> scope) {
-        irForEachSubArrayNode.getConditionNode().visit(this, irForEachSubArrayNode::setConditionNode);
-        irForEachSubArrayNode.getBlockNode().visit(this, null);
-    }
-
-    @Override
-    public void visitForEachSubIterableLoop(ForEachSubIterableNode irForEachSubIterableNode, Consumer<ExpressionNode> scope) {
-        irForEachSubIterableNode.getConditionNode().visit(this, irForEachSubIterableNode::setConditionNode);
-        irForEachSubIterableNode.getBlockNode().visit(this, null);
-    }
-
-    @Override
-    public void visitDeclaration(DeclarationNode irDeclarationNode, Consumer<ExpressionNode> scope) {
-        if (irDeclarationNode.getExpressionNode() != null) {
-            irDeclarationNode.getExpressionNode().visit(this, irDeclarationNode::setExpressionNode);
-        }
-    }
-
-    @Override
-    public void visitReturn(ReturnNode irReturnNode, Consumer<ExpressionNode> scope) {
-        if (irReturnNode.getExpressionNode() != null) {
-            irReturnNode.getExpressionNode().visit(this, irReturnNode::setExpressionNode);
-        }
-    }
-
-    @Override
-    public void visitStatementExpression(StatementExpressionNode irStatementExpressionNode, Consumer<ExpressionNode> scope) {
-        irStatementExpressionNode.getExpressionNode().visit(this, irStatementExpressionNode::setExpressionNode);
-    }
-
-    @Override
-    public void visitThrow(ThrowNode irThrowNode, Consumer<ExpressionNode> scope) {
-        irThrowNode.getExpressionNode().visit(this, irThrowNode::setExpressionNode);
-    }
-
-    @Override
-    public void visitBinaryImpl(BinaryImplNode irBinaryImplNode, Consumer<ExpressionNode> scope) {
-        irBinaryImplNode.getLeftNode().visit(this, irBinaryImplNode::setLeftNode);
-        irBinaryImplNode.getRightNode().visit(this, irBinaryImplNode::setRightNode);
-    }
+public class DefaultConstantFoldingOptimizationPhase extends IRExpressionModifyingVisitor {
 
     @Override
     public void visitUnaryMath(UnaryMathNode irUnaryMathNode, Consumer<ExpressionNode> scope) {
         irUnaryMathNode.getChildNode().visit(this, irUnaryMathNode::setChildNode);
 
-        if (irUnaryMathNode.getChildNode() instanceof ConstantNode) {
-            ConstantNode irConstantNode = (ConstantNode) irUnaryMathNode.getChildNode();
+        if (irUnaryMathNode.getChildNode()instanceof ConstantNode irConstantNode) {
             Object constantValue = irConstantNode.getDecorationValue(IRDConstant.class);
             Operation operation = irUnaryMathNode.getDecorationValue(IRDOperation.class);
             Class<?> type = irUnaryMathNode.getDecorationValue(IRDExpressionType.class);
@@ -269,9 +139,8 @@ public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<C
         irBinaryMathNode.getLeftNode().visit(this, irBinaryMathNode::setLeftNode);
         irBinaryMathNode.getRightNode().visit(this, irBinaryMathNode::setRightNode);
 
-        if (irBinaryMathNode.getLeftNode() instanceof ConstantNode && irBinaryMathNode.getRightNode() instanceof ConstantNode) {
-            ConstantNode irLeftConstantNode = (ConstantNode) irBinaryMathNode.getLeftNode();
-            ConstantNode irRightConstantNode = (ConstantNode) irBinaryMathNode.getRightNode();
+        if (irBinaryMathNode.getLeftNode()instanceof ConstantNode irLeftConstantNode
+            && irBinaryMathNode.getRightNode()instanceof ConstantNode irRightConstantNode) {
             Object leftConstantValue = irLeftConstantNode.getDecorationValue(IRDConstant.class);
             Object rightConstantValue = irRightConstantNode.getDecorationValue(IRDConstant.class);
             Operation operation = irBinaryMathNode.getDecorationValue(IRDOperation.class);
@@ -622,23 +491,20 @@ public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<C
             irRightNode.visit(this, (e) -> irStringConcatenationNode.getArgumentNodes().set(j + 1, e));
 
             if (irLeftNode instanceof ConstantNode && irRightNode instanceof ConstantNode) {
-                ConstantNode irConstantNode = (ConstantNode) irLeftNode;
-                irConstantNode.attachDecoration(
+                irLeftNode.attachDecoration(
                     new IRDConstant(
-                        "" + irConstantNode.getDecorationValue(IRDConstant.class) + irRightNode.getDecorationValue(IRDConstant.class)
+                        "" + irLeftNode.getDecorationValue(IRDConstant.class) + irRightNode.getDecorationValue(IRDConstant.class)
                     )
                 );
-                irConstantNode.attachDecoration(new IRDExpressionType(String.class));
+                irLeftNode.attachDecoration(new IRDExpressionType(String.class));
                 irStringConcatenationNode.getArgumentNodes().remove(i + 1);
             } else if (irLeftNode instanceof NullNode && irRightNode instanceof ConstantNode) {
-                ConstantNode irConstantNode = (ConstantNode) irRightNode;
-                irConstantNode.attachDecoration(new IRDConstant("" + null + irRightNode.getDecorationValue(IRDConstant.class)));
-                irConstantNode.attachDecoration(new IRDExpressionType(String.class));
+                irRightNode.attachDecoration(new IRDConstant("" + null + irRightNode.getDecorationValue(IRDConstant.class)));
+                irRightNode.attachDecoration(new IRDExpressionType(String.class));
                 irStringConcatenationNode.getArgumentNodes().remove(i);
             } else if (irLeftNode instanceof ConstantNode && irRightNode instanceof NullNode) {
-                ConstantNode irConstantNode = (ConstantNode) irLeftNode;
-                irConstantNode.attachDecoration(new IRDConstant("" + irLeftNode.getDecorationValue(IRDConstant.class) + null));
-                irConstantNode.attachDecoration(new IRDExpressionType(String.class));
+                irLeftNode.attachDecoration(new IRDConstant("" + irLeftNode.getDecorationValue(IRDConstant.class) + null));
+                irLeftNode.attachDecoration(new IRDExpressionType(String.class));
                 irStringConcatenationNode.getArgumentNodes().remove(i + 1);
             } else if (irLeftNode instanceof NullNode && irRightNode instanceof NullNode) {
                 ConstantNode irConstantNode = new ConstantNode(irLeftNode.getLocation());
@@ -665,9 +531,8 @@ public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<C
         irBooleanNode.getLeftNode().visit(this, irBooleanNode::setLeftNode);
         irBooleanNode.getRightNode().visit(this, irBooleanNode::setRightNode);
 
-        if (irBooleanNode.getLeftNode() instanceof ConstantNode && irBooleanNode.getRightNode() instanceof ConstantNode) {
-            ConstantNode irLeftConstantNode = (ConstantNode) irBooleanNode.getLeftNode();
-            ConstantNode irRightConstantNode = (ConstantNode) irBooleanNode.getRightNode();
+        if (irBooleanNode.getLeftNode()instanceof ConstantNode irLeftConstantNode
+            && irBooleanNode.getRightNode()instanceof ConstantNode irRightConstantNode) {
             Operation operation = irBooleanNode.getDecorationValue(IRDOperation.class);
             Class<?> type = irBooleanNode.getDecorationValue(IRDExpressionType.class);
 
@@ -946,9 +811,8 @@ public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<C
     public void visitCast(CastNode irCastNode, Consumer<ExpressionNode> scope) {
         irCastNode.getChildNode().visit(this, irCastNode::setChildNode);
 
-        if (irCastNode.getChildNode() instanceof ConstantNode
+        if (irCastNode.getChildNode()instanceof ConstantNode irConstantNode
             && PainlessLookupUtility.isConstantType(irCastNode.getDecorationValue(IRDExpressionType.class))) {
-            ConstantNode irConstantNode = (ConstantNode) irCastNode.getChildNode();
             Object constantValue = irConstantNode.getDecorationValue(IRDConstant.class);
             constantValue = AnalyzerCaster.constCast(irCastNode.getLocation(), constantValue, irCastNode.getDecorationValue(IRDCast.class));
             irConstantNode.attachDecoration(new IRDConstant(constantValue));
@@ -957,127 +821,6 @@ public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<C
         }
     }
 
-    @Override
-    public void visitInstanceof(InstanceofNode irInstanceofNode, Consumer<ExpressionNode> scope) {
-        irInstanceofNode.getChildNode().visit(this, irInstanceofNode::setChildNode);
-    }
-
-    @Override
-    public void visitConditional(ConditionalNode irConditionalNode, Consumer<ExpressionNode> scope) {
-        irConditionalNode.getConditionNode().visit(this, irConditionalNode::setConditionNode);
-        irConditionalNode.getLeftNode().visit(this, irConditionalNode::setLeftNode);
-        irConditionalNode.getRightNode().visit(this, irConditionalNode::setRightNode);
-    }
-
-    @Override
-    public void visitElvis(ElvisNode irElvisNode, Consumer<ExpressionNode> scope) {
-        irElvisNode.getLeftNode().visit(this, irElvisNode::setLeftNode);
-        irElvisNode.getRightNode().visit(this, irElvisNode::setRightNode);
-    }
-
-    @Override
-    public void visitListInitialization(ListInitializationNode irListInitializationNode, Consumer<ExpressionNode> scope) {
-        for (int i = 0; i < irListInitializationNode.getArgumentNodes().size(); i++) {
-            int j = i;
-            irListInitializationNode.getArgumentNodes().get(i).visit(this, (e) -> irListInitializationNode.getArgumentNodes().set(j, e));
-        }
-    }
-
-    @Override
-    public void visitMapInitialization(MapInitializationNode irMapInitializationNode, Consumer<ExpressionNode> scope) {
-        for (int i = 0; i < irMapInitializationNode.getKeyNodes().size(); i++) {
-            int j = i;
-            irMapInitializationNode.getKeyNode(i).visit(this, (e) -> irMapInitializationNode.getKeyNodes().set(j, e));
-        }
-
-        for (int i = 0; i < irMapInitializationNode.getValueNodes().size(); i++) {
-            int j = i;
-            irMapInitializationNode.getValueNode(i).visit(this, (e) -> irMapInitializationNode.getValueNodes().set(j, e));
-        }
-    }
-
-    @Override
-    public void visitNewArray(NewArrayNode irNewArrayNode, Consumer<ExpressionNode> scope) {
-        for (int i = 0; i < irNewArrayNode.getArgumentNodes().size(); i++) {
-            int j = i;
-            irNewArrayNode.getArgumentNodes().get(i).visit(this, (e) -> irNewArrayNode.getArgumentNodes().set(j, e));
-        }
-    }
-
-    @Override
-    public void visitNewObject(NewObjectNode irNewObjectNode, Consumer<ExpressionNode> scope) {
-        for (int i = 0; i < irNewObjectNode.getArgumentNodes().size(); i++) {
-            int j = i;
-            irNewObjectNode.getArgumentNodes().get(i).visit(this, (e) -> irNewObjectNode.getArgumentNodes().set(j, e));
-        }
-    }
-
-    @Override
-    public void visitNullSafeSub(NullSafeSubNode irNullSafeSubNode, Consumer<ExpressionNode> scope) {
-        irNullSafeSubNode.getChildNode().visit(this, irNullSafeSubNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreVariable(StoreVariableNode irStoreVariableNode, Consumer<ExpressionNode> scope) {
-        irStoreVariableNode.getChildNode().visit(this, irStoreVariableNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreDotDef(StoreDotDefNode irStoreDotDefNode, Consumer<ExpressionNode> scope) {
-        irStoreDotDefNode.getChildNode().visit(this, irStoreDotDefNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreDot(StoreDotNode irStoreDotNode, Consumer<ExpressionNode> scope) {
-        irStoreDotNode.getChildNode().visit(this, irStoreDotNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreDotShortcut(StoreDotShortcutNode irDotSubShortcutNode, Consumer<ExpressionNode> scope) {
-        irDotSubShortcutNode.getChildNode().visit(this, irDotSubShortcutNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreListShortcut(StoreListShortcutNode irStoreListShortcutNode, Consumer<ExpressionNode> scope) {
-        irStoreListShortcutNode.getChildNode().visit(this, irStoreListShortcutNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreMapShortcut(StoreMapShortcutNode irStoreMapShortcutNode, Consumer<ExpressionNode> scope) {
-        irStoreMapShortcutNode.getChildNode().visit(this, irStoreMapShortcutNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreFieldMember(StoreFieldMemberNode irStoreFieldMemberNode, Consumer<ExpressionNode> scope) {
-        irStoreFieldMemberNode.getChildNode().visit(this, irStoreFieldMemberNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreBraceDef(StoreBraceDefNode irStoreBraceDefNode, Consumer<ExpressionNode> scope) {
-        irStoreBraceDefNode.getChildNode().visit(this, irStoreBraceDefNode::setChildNode);
-    }
-
-    @Override
-    public void visitStoreBrace(StoreBraceNode irStoreBraceNode, Consumer<ExpressionNode> scope) {
-        irStoreBraceNode.getChildNode().visit(this, irStoreBraceNode::setChildNode);
-    }
-
-    @Override
-    public void visitInvokeCallDef(InvokeCallDefNode irInvokeCallDefNode, Consumer<ExpressionNode> scope) {
-        for (int i = 0; i < irInvokeCallDefNode.getArgumentNodes().size(); i++) {
-            int j = i;
-            irInvokeCallDefNode.getArgumentNodes().get(i).visit(this, (e) -> irInvokeCallDefNode.getArgumentNodes().set(j, e));
-        }
-    }
-
-    @Override
-    public void visitInvokeCall(InvokeCallNode irInvokeCallNode, Consumer<ExpressionNode> scope) {
-        for (int i = 0; i < irInvokeCallNode.getArgumentNodes().size(); i++) {
-            int j = i;
-            irInvokeCallNode.getArgumentNodes().get(i).visit(this, (e) -> irInvokeCallNode.getArgumentNodes().set(j, e));
-        }
-    }
-
     @Override
     public void visitInvokeCallMember(InvokeCallMemberNode irInvokeCallMemberNode, Consumer<ExpressionNode> scope) {
         for (int i = 0; i < irInvokeCallMemberNode.getArgumentNodes().size(); i++) {
@@ -1132,24 +875,4 @@ public class DefaultConstantFoldingOptimizationPhase extends IRTreeBaseVisitor<C
         replacement.attachDecoration(irInvokeCallMemberNode.getDecoration(IRDExpressionType.class));
         scope.accept(replacement);
     }
-
-    @Override
-    public void visitFlipArrayIndex(FlipArrayIndexNode irFlipArrayIndexNode, Consumer<ExpressionNode> scope) {
-        irFlipArrayIndexNode.getChildNode().visit(this, irFlipArrayIndexNode::setChildNode);
-    }
-
-    @Override
-    public void visitFlipCollectionIndex(FlipCollectionIndexNode irFlipCollectionIndexNode, Consumer<ExpressionNode> scope) {
-        irFlipCollectionIndexNode.getChildNode().visit(this, irFlipCollectionIndexNode::setChildNode);
-    }
-
-    @Override
-    public void visitFlipDefIndex(FlipDefIndexNode irFlipDefIndexNode, Consumer<ExpressionNode> scope) {
-        irFlipDefIndexNode.getChildNode().visit(this, irFlipDefIndexNode::setChildNode);
-    }
-
-    @Override
-    public void visitDup(DupNode irDupNode, Consumer<ExpressionNode> scope) {
-        irDupNode.getChildNode().visit(this, irDupNode::setChildNode);
-    }
 }

+ 88 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/DefaultEqualityMethodOptimizationPhase.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.painless.phase;
+
+import org.elasticsearch.painless.Location;
+import org.elasticsearch.painless.Operation;
+import org.elasticsearch.painless.ir.BinaryImplNode;
+import org.elasticsearch.painless.ir.ComparisonNode;
+import org.elasticsearch.painless.ir.ExpressionNode;
+import org.elasticsearch.painless.ir.InvokeCallNode;
+import org.elasticsearch.painless.ir.UnaryMathNode;
+import org.elasticsearch.painless.lookup.PainlessMethod;
+import org.elasticsearch.painless.symbol.IRDecorations.IRDConstant;
+import org.elasticsearch.painless.symbol.IRDecorations.IRDExpressionType;
+import org.elasticsearch.painless.symbol.IRDecorations.IRDOperation;
+import org.elasticsearch.painless.symbol.ScriptScope;
+
+import java.util.function.Consumer;
+
+/**
+ * Phases that changes ==/!= to use String.equals when one side is a constant string
+ */
+public class DefaultEqualityMethodOptimizationPhase extends IRExpressionModifyingVisitor {
+
+    private final ScriptScope scriptScope;
+
+    public DefaultEqualityMethodOptimizationPhase(ScriptScope scriptScope) {
+        this.scriptScope = scriptScope;
+    }
+
+    @Override
+    public void visitComparison(ComparisonNode irComparisonNode, Consumer<ExpressionNode> scope) {
+        super.visitComparison(irComparisonNode, scope);
+
+        Operation op = irComparisonNode.getDecorationValue(IRDOperation.class);
+        if (op == Operation.EQ || op == Operation.NE) {
+            ExpressionNode constantNode = null;
+            ExpressionNode argumentNode = null;
+            if (irComparisonNode.getLeftNode().getDecorationValue(IRDConstant.class) instanceof String) {
+                constantNode = irComparisonNode.getLeftNode();
+                argumentNode = irComparisonNode.getRightNode();
+            } else if (irComparisonNode.getRightNode().getDecorationValue(IRDConstant.class) instanceof String) {
+                // it's ok to reorder these, RHS is a constant that has no effect on execution
+                constantNode = irComparisonNode.getRightNode();
+                argumentNode = irComparisonNode.getLeftNode();
+            }
+
+            ExpressionNode node = null;
+            Location loc = irComparisonNode.getLocation();
+            if (constantNode != null) {
+                // call String.equals directly
+                InvokeCallNode invoke = new InvokeCallNode(loc);
+                PainlessMethod method = scriptScope.getPainlessLookup().lookupPainlessMethod(String.class, false, "equals", 1);
+                invoke.setMethod(method);
+                invoke.setBox(String.class);
+                invoke.addArgumentNode(argumentNode);
+                invoke.attachDecoration(new IRDExpressionType(boolean.class));
+
+                BinaryImplNode call = new BinaryImplNode(loc);
+                call.setLeftNode(constantNode);
+                call.setRightNode(invoke);
+                call.attachDecoration(new IRDExpressionType(boolean.class));
+
+                node = call;
+            }
+
+            if (node != null) {
+                if (op == Operation.NE) {
+                    UnaryMathNode not = new UnaryMathNode(loc);
+                    not.setChildNode(node);
+                    not.attachDecoration(new IRDOperation(Operation.NOT));
+                    not.attachDecoration(new IRDExpressionType(boolean.class));
+                    node = not;
+                }
+
+                // replace the comparison with this node
+                scope.accept(node);
+            }
+        }
+    }
+
+}

+ 317 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/phase/IRExpressionModifyingVisitor.java

@@ -0,0 +1,317 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.painless.phase;
+
+import org.elasticsearch.painless.ir.BinaryImplNode;
+import org.elasticsearch.painless.ir.BinaryMathNode;
+import org.elasticsearch.painless.ir.BooleanNode;
+import org.elasticsearch.painless.ir.CastNode;
+import org.elasticsearch.painless.ir.ComparisonNode;
+import org.elasticsearch.painless.ir.ConditionalNode;
+import org.elasticsearch.painless.ir.DeclarationNode;
+import org.elasticsearch.painless.ir.DoWhileLoopNode;
+import org.elasticsearch.painless.ir.DupNode;
+import org.elasticsearch.painless.ir.ElvisNode;
+import org.elasticsearch.painless.ir.ExpressionNode;
+import org.elasticsearch.painless.ir.FlipArrayIndexNode;
+import org.elasticsearch.painless.ir.FlipCollectionIndexNode;
+import org.elasticsearch.painless.ir.FlipDefIndexNode;
+import org.elasticsearch.painless.ir.ForEachSubArrayNode;
+import org.elasticsearch.painless.ir.ForEachSubIterableNode;
+import org.elasticsearch.painless.ir.ForLoopNode;
+import org.elasticsearch.painless.ir.IfElseNode;
+import org.elasticsearch.painless.ir.IfNode;
+import org.elasticsearch.painless.ir.InstanceofNode;
+import org.elasticsearch.painless.ir.InvokeCallDefNode;
+import org.elasticsearch.painless.ir.InvokeCallMemberNode;
+import org.elasticsearch.painless.ir.InvokeCallNode;
+import org.elasticsearch.painless.ir.ListInitializationNode;
+import org.elasticsearch.painless.ir.MapInitializationNode;
+import org.elasticsearch.painless.ir.NewArrayNode;
+import org.elasticsearch.painless.ir.NewObjectNode;
+import org.elasticsearch.painless.ir.NullSafeSubNode;
+import org.elasticsearch.painless.ir.ReturnNode;
+import org.elasticsearch.painless.ir.StatementExpressionNode;
+import org.elasticsearch.painless.ir.StoreBraceDefNode;
+import org.elasticsearch.painless.ir.StoreBraceNode;
+import org.elasticsearch.painless.ir.StoreDotDefNode;
+import org.elasticsearch.painless.ir.StoreDotNode;
+import org.elasticsearch.painless.ir.StoreDotShortcutNode;
+import org.elasticsearch.painless.ir.StoreFieldMemberNode;
+import org.elasticsearch.painless.ir.StoreListShortcutNode;
+import org.elasticsearch.painless.ir.StoreMapShortcutNode;
+import org.elasticsearch.painless.ir.StoreVariableNode;
+import org.elasticsearch.painless.ir.StringConcatenationNode;
+import org.elasticsearch.painless.ir.ThrowNode;
+import org.elasticsearch.painless.ir.UnaryMathNode;
+import org.elasticsearch.painless.ir.WhileLoopNode;
+
+import java.util.List;
+import java.util.function.Consumer;
+
+public class IRExpressionModifyingVisitor extends IRTreeBaseVisitor<Consumer<ExpressionNode>> {
+
+    private void visitList(List<ExpressionNode> nodes) {
+        for (int i = 0; i < nodes.size(); i++) {
+            int ii = i;
+            nodes.get(i).visit(this, e -> nodes.set(ii, e));
+        }
+    }
+
+    @Override
+    public void visitIf(IfNode irIfNode, Consumer<ExpressionNode> scope) {
+        irIfNode.getConditionNode().visit(this, irIfNode::setConditionNode);
+        irIfNode.getBlockNode().visit(this, null);
+    }
+
+    @Override
+    public void visitIfElse(IfElseNode irIfElseNode, Consumer<ExpressionNode> scope) {
+        irIfElseNode.getConditionNode().visit(this, irIfElseNode::setConditionNode);
+        irIfElseNode.getBlockNode().visit(this, null);
+        irIfElseNode.getElseBlockNode().visit(this, null);
+    }
+
+    @Override
+    public void visitWhileLoop(WhileLoopNode irWhileLoopNode, Consumer<ExpressionNode> scope) {
+        if (irWhileLoopNode.getConditionNode() != null) {
+            irWhileLoopNode.getConditionNode().visit(this, irWhileLoopNode::setConditionNode);
+        }
+
+        if (irWhileLoopNode.getBlockNode() != null) {
+            irWhileLoopNode.getBlockNode().visit(this, null);
+        }
+    }
+
+    @Override
+    public void visitDoWhileLoop(DoWhileLoopNode irDoWhileLoopNode, Consumer<ExpressionNode> scope) {
+        irDoWhileLoopNode.getBlockNode().visit(this, null);
+
+        if (irDoWhileLoopNode.getConditionNode() != null) {
+            irDoWhileLoopNode.getConditionNode().visit(this, irDoWhileLoopNode::setConditionNode);
+        }
+    }
+
+    @Override
+    public void visitForLoop(ForLoopNode irForLoopNode, Consumer<ExpressionNode> scope) {
+        if (irForLoopNode.getInitializerNode() != null) {
+            irForLoopNode.getInitializerNode().visit(this, irForLoopNode::setInitializerNode);
+        }
+
+        if (irForLoopNode.getConditionNode() != null) {
+            irForLoopNode.getConditionNode().visit(this, irForLoopNode::setConditionNode);
+        }
+
+        if (irForLoopNode.getAfterthoughtNode() != null) {
+            irForLoopNode.getAfterthoughtNode().visit(this, irForLoopNode::setAfterthoughtNode);
+        }
+
+        if (irForLoopNode.getBlockNode() != null) {
+            irForLoopNode.getBlockNode().visit(this, null);
+        }
+    }
+
+    @Override
+    public void visitForEachSubArrayLoop(ForEachSubArrayNode irForEachSubArrayNode, Consumer<ExpressionNode> scope) {
+        irForEachSubArrayNode.getConditionNode().visit(this, irForEachSubArrayNode::setConditionNode);
+        irForEachSubArrayNode.getBlockNode().visit(this, null);
+    }
+
+    @Override
+    public void visitForEachSubIterableLoop(ForEachSubIterableNode irForEachSubIterableNode, Consumer<ExpressionNode> scope) {
+        irForEachSubIterableNode.getConditionNode().visit(this, irForEachSubIterableNode::setConditionNode);
+        irForEachSubIterableNode.getBlockNode().visit(this, null);
+    }
+
+    @Override
+    public void visitDeclaration(DeclarationNode irDeclarationNode, Consumer<ExpressionNode> scope) {
+        if (irDeclarationNode.getExpressionNode() != null) {
+            irDeclarationNode.getExpressionNode().visit(this, irDeclarationNode::setExpressionNode);
+        }
+    }
+
+    @Override
+    public void visitReturn(ReturnNode irReturnNode, Consumer<ExpressionNode> scope) {
+        if (irReturnNode.getExpressionNode() != null) {
+            irReturnNode.getExpressionNode().visit(this, irReturnNode::setExpressionNode);
+        }
+    }
+
+    @Override
+    public void visitStatementExpression(StatementExpressionNode irStatementExpressionNode, Consumer<ExpressionNode> scope) {
+        irStatementExpressionNode.getExpressionNode().visit(this, irStatementExpressionNode::setExpressionNode);
+    }
+
+    @Override
+    public void visitThrow(ThrowNode irThrowNode, Consumer<ExpressionNode> scope) {
+        irThrowNode.getExpressionNode().visit(this, irThrowNode::setExpressionNode);
+    }
+
+    @Override
+    public void visitBinaryImpl(BinaryImplNode irBinaryImplNode, Consumer<ExpressionNode> scope) {
+        irBinaryImplNode.getLeftNode().visit(this, irBinaryImplNode::setLeftNode);
+        irBinaryImplNode.getRightNode().visit(this, irBinaryImplNode::setRightNode);
+    }
+
+    @Override
+    public void visitUnaryMath(UnaryMathNode irUnaryMathNode, Consumer<ExpressionNode> scope) {
+        irUnaryMathNode.getChildNode().visit(this, irUnaryMathNode::setChildNode);
+    }
+
+    @Override
+    public void visitBinaryMath(BinaryMathNode irBinaryMathNode, Consumer<ExpressionNode> scope) {
+        irBinaryMathNode.getLeftNode().visit(this, irBinaryMathNode::setLeftNode);
+        irBinaryMathNode.getRightNode().visit(this, irBinaryMathNode::setRightNode);
+    }
+
+    @Override
+    public void visitStringConcatenation(StringConcatenationNode irStringConcatenationNode, Consumer<ExpressionNode> scope) {
+        visitList(irStringConcatenationNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitBoolean(BooleanNode irBooleanNode, Consumer<ExpressionNode> scope) {
+        irBooleanNode.getLeftNode().visit(this, irBooleanNode::setLeftNode);
+        irBooleanNode.getRightNode().visit(this, irBooleanNode::setRightNode);
+    }
+
+    @Override
+    public void visitComparison(ComparisonNode irComparisonNode, Consumer<ExpressionNode> scope) {
+        irComparisonNode.getLeftNode().visit(this, irComparisonNode::setLeftNode);
+        irComparisonNode.getRightNode().visit(this, irComparisonNode::setRightNode);
+    }
+
+    @Override
+    public void visitCast(CastNode irCastNode, Consumer<ExpressionNode> scope) {
+        irCastNode.getChildNode().visit(this, irCastNode::setChildNode);
+    }
+
+    @Override
+    public void visitInstanceof(InstanceofNode irInstanceofNode, Consumer<ExpressionNode> scope) {
+        irInstanceofNode.getChildNode().visit(this, irInstanceofNode::setChildNode);
+    }
+
+    @Override
+    public void visitConditional(ConditionalNode irConditionalNode, Consumer<ExpressionNode> scope) {
+        irConditionalNode.getConditionNode().visit(this, irConditionalNode::setConditionNode);
+        irConditionalNode.getLeftNode().visit(this, irConditionalNode::setLeftNode);
+        irConditionalNode.getRightNode().visit(this, irConditionalNode::setRightNode);
+    }
+
+    @Override
+    public void visitElvis(ElvisNode irElvisNode, Consumer<ExpressionNode> scope) {
+        irElvisNode.getLeftNode().visit(this, irElvisNode::setLeftNode);
+        irElvisNode.getRightNode().visit(this, irElvisNode::setRightNode);
+    }
+
+    @Override
+    public void visitListInitialization(ListInitializationNode irListInitializationNode, Consumer<ExpressionNode> scope) {
+        visitList(irListInitializationNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitMapInitialization(MapInitializationNode irMapInitializationNode, Consumer<ExpressionNode> scope) {
+        visitList(irMapInitializationNode.getKeyNodes());
+        visitList(irMapInitializationNode.getValueNodes());
+    }
+
+    @Override
+    public void visitNewArray(NewArrayNode irNewArrayNode, Consumer<ExpressionNode> scope) {
+        visitList(irNewArrayNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitNewObject(NewObjectNode irNewObjectNode, Consumer<ExpressionNode> scope) {
+        visitList(irNewObjectNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitNullSafeSub(NullSafeSubNode irNullSafeSubNode, Consumer<ExpressionNode> scope) {
+        irNullSafeSubNode.getChildNode().visit(this, irNullSafeSubNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreVariable(StoreVariableNode irStoreVariableNode, Consumer<ExpressionNode> scope) {
+        irStoreVariableNode.getChildNode().visit(this, irStoreVariableNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreDotDef(StoreDotDefNode irStoreDotDefNode, Consumer<ExpressionNode> scope) {
+        irStoreDotDefNode.getChildNode().visit(this, irStoreDotDefNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreDot(StoreDotNode irStoreDotNode, Consumer<ExpressionNode> scope) {
+        irStoreDotNode.getChildNode().visit(this, irStoreDotNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreDotShortcut(StoreDotShortcutNode irDotSubShortcutNode, Consumer<ExpressionNode> scope) {
+        irDotSubShortcutNode.getChildNode().visit(this, irDotSubShortcutNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreListShortcut(StoreListShortcutNode irStoreListShortcutNode, Consumer<ExpressionNode> scope) {
+        irStoreListShortcutNode.getChildNode().visit(this, irStoreListShortcutNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreMapShortcut(StoreMapShortcutNode irStoreMapShortcutNode, Consumer<ExpressionNode> scope) {
+        irStoreMapShortcutNode.getChildNode().visit(this, irStoreMapShortcutNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreFieldMember(StoreFieldMemberNode irStoreFieldMemberNode, Consumer<ExpressionNode> scope) {
+        irStoreFieldMemberNode.getChildNode().visit(this, irStoreFieldMemberNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreBraceDef(StoreBraceDefNode irStoreBraceDefNode, Consumer<ExpressionNode> scope) {
+        irStoreBraceDefNode.getChildNode().visit(this, irStoreBraceDefNode::setChildNode);
+    }
+
+    @Override
+    public void visitStoreBrace(StoreBraceNode irStoreBraceNode, Consumer<ExpressionNode> scope) {
+        irStoreBraceNode.getChildNode().visit(this, irStoreBraceNode::setChildNode);
+    }
+
+    @Override
+    public void visitInvokeCallDef(InvokeCallDefNode irInvokeCallDefNode, Consumer<ExpressionNode> scope) {
+        visitList(irInvokeCallDefNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitInvokeCall(InvokeCallNode irInvokeCallNode, Consumer<ExpressionNode> scope) {
+        visitList(irInvokeCallNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitInvokeCallMember(InvokeCallMemberNode irInvokeCallMemberNode, Consumer<ExpressionNode> scope) {
+        visitList(irInvokeCallMemberNode.getArgumentNodes());
+    }
+
+    @Override
+    public void visitFlipArrayIndex(FlipArrayIndexNode irFlipArrayIndexNode, Consumer<ExpressionNode> scope) {
+        irFlipArrayIndexNode.getChildNode().visit(this, irFlipArrayIndexNode::setChildNode);
+    }
+
+    @Override
+    public void visitFlipCollectionIndex(FlipCollectionIndexNode irFlipCollectionIndexNode, Consumer<ExpressionNode> scope) {
+        irFlipCollectionIndexNode.getChildNode().visit(this, irFlipCollectionIndexNode::setChildNode);
+    }
+
+    @Override
+    public void visitFlipDefIndex(FlipDefIndexNode irFlipDefIndexNode, Consumer<ExpressionNode> scope) {
+        irFlipDefIndexNode.getChildNode().visit(this, irFlipDefIndexNode::setChildNode);
+    }
+
+    @Override
+    public void visitDup(DupNode irDupNode, Consumer<ExpressionNode> scope) {
+        irDupNode.getChildNode().visit(this, irDupNode::setChildNode);
+    }
+}

+ 19 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/EqualsTests.java

@@ -183,4 +183,23 @@ public class EqualsTests extends ScriptTestCase {
         assertEquals(true, exec("HashMap a = new HashMap(); return null != a;"));
         assertEquals(true, exec("HashMap a = new HashMap(); return null !== a;"));
     }
+
+    public void testStringEquals() {
+        assertEquals(false, exec("def x = null; return \"a\" == x"));
+        assertEquals(true, exec("def x = \"a\"; return \"a\" == x"));
+        assertEquals(true, exec("def x = null; return \"a\" != x"));
+        assertEquals(false, exec("def x = \"a\"; return \"a\" != x"));
+
+        assertEquals(false, exec("def x = null; return x == \"a\""));
+        assertEquals(true, exec("def x = \"a\"; return x == \"a\""));
+        assertEquals(true, exec("def x = null; return x != \"a\""));
+        assertEquals(false, exec("def x = \"a\"; return x != \"a\""));
+    }
+
+    public void testStringEqualsMethodCall() {
+        assertBytecodeExists("def x = \"a\"; return \"a\" == x", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
+        assertBytecodeExists("def x = \"a\"; return \"a\" != x", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
+        assertBytecodeExists("def x = \"a\"; return x == \"a\"", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
+        assertBytecodeExists("def x = \"a\"; return x != \"a\"", "INVOKEVIRTUAL java/lang/Object.equals (Ljava/lang/Object;)Z");
+    }
 }