瀏覽代碼

SQL: Prevent StackOverflowError when parsing large statements (#33902)

Implement circuit breaker logic in the parser which catches expressions
that can blow up the tree and result in StackOverflowError being thrown.

Co-authored-by: Costin Leau <costin.leau@gmail.com>
Marios Trivyzas 7 年之前
父節點
當前提交
5840be6a6b

+ 39 - 5
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.sql.parser;
 
+import com.carrotsearch.hppc.ObjectShortHashMap;
 import org.antlr.v4.runtime.BaseErrorListener;
 import org.antlr.v4.runtime.CharStream;
 import org.antlr.v4.runtime.CommonToken;
@@ -22,8 +23,8 @@ import org.antlr.v4.runtime.atn.PredictionMode;
 import org.antlr.v4.runtime.dfa.DFA;
 import org.antlr.v4.runtime.misc.Pair;
 import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.xpack.sql.expression.Expression;
 import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
@@ -41,7 +42,8 @@ import java.util.function.Function;
 import static java.lang.String.format;
 
 public class SqlParser {
-    private static final Logger log = Loggers.getLogger(SqlParser.class);
+
+    private static final Logger log = LogManager.getLogger();
 
     private final boolean DEBUG = false;
 
@@ -83,7 +85,9 @@ public class SqlParser {
         return invokeParser(expression, params, SqlBaseParser::singleExpression, AstBuilder::expression);
     }
 
-    private <T> T invokeParser(String sql, List<SqlTypedParamValue> params, Function<SqlBaseParser, ParserRuleContext> parseFunction,
+    private <T> T invokeParser(String sql,
+                               List<SqlTypedParamValue> params, Function<SqlBaseParser,
+                               ParserRuleContext> parseFunction,
                                BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
         SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
 
@@ -96,6 +100,7 @@ public class SqlParser {
         CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
         SqlBaseParser parser = new SqlBaseParser(tokenStream);
 
+        parser.addParseListener(new CircuitBreakerListener());
         parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
 
         parser.removeErrorListeners();
@@ -125,7 +130,7 @@ public class SqlParser {
         return visitor.apply(new AstBuilder(paramTokens), tree);
     }
 
-    private void debug(SqlBaseParser parser) {
+    private static void debug(SqlBaseParser parser) {
         
         // when debugging, use the exact prediction mode (needed for diagnostics as well)
         parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
@@ -154,7 +159,7 @@ public class SqlParser {
         public void exitBackQuotedIdentifier(SqlBaseParser.BackQuotedIdentifierContext context) {
             Token token = context.BACKQUOTED_IDENTIFIER().getSymbol();
             throw new ParsingException(
-                    "backquoted indetifiers not supported; please use double quotes instead",
+                    "backquoted identifiers not supported; please use double quotes instead",
                     null,
                     token.getLine(),
                     token.getCharPositionInLine());
@@ -205,6 +210,35 @@ public class SqlParser {
         }
     }
 
+    /**
+     * Used to catch large expressions that can lead to stack overflows
+     */
+    private class CircuitBreakerListener extends SqlBaseBaseListener {
+
+        private static final short MAX_RULE_DEPTH = 100;
+
+        // Keep current depth for every rule visited.
+        // The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ...
+        // are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing
+        // while the stack call depth is, leading to a StackOverflowError.
+        private ObjectShortHashMap<String> depthCounts = new ObjectShortHashMap<>();
+
+        @Override
+        public void enterEveryRule(ParserRuleContext ctx) {
+            short currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
+            if (currentDepth > MAX_RULE_DEPTH) {
+                throw new ParsingException("expression is too large to parse, (tree's depth exceeds {})", MAX_RULE_DEPTH);
+            }
+            super.enterEveryRule(ctx);
+        }
+
+        @Override
+        public void exitEveryRule(ParserRuleContext ctx) {
+            depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1);
+            super.exitEveryRule(ctx);
+        }
+    }
+
     private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
         @Override
         public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,

+ 2 - 2
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/QuotingTests.java

@@ -70,7 +70,7 @@ public class QuotingTests extends ESTestCase {
         String name = "@timestamp";
         ParsingException ex = expectThrows(ParsingException.class, () ->
             new SqlParser().createExpression(quote + name + quote));
-        assertThat(ex.getMessage(), equalTo("line 1:1: backquoted indetifiers not supported; please use double quotes instead"));
+        assertThat(ex.getMessage(), equalTo("line 1:1: backquoted identifiers not supported; please use double quotes instead"));
     }
 
     public void testQuotedAttributeAndQualifier() {
@@ -92,7 +92,7 @@ public class QuotingTests extends ESTestCase {
         String name = "@timestamp";
         ParsingException ex = expectThrows(ParsingException.class, () ->
             new SqlParser().createExpression(quote + qualifier + quote + "." + quote + name + quote));
-        assertThat(ex.getMessage(), equalTo("line 1:1: backquoted indetifiers not supported; please use double quotes instead"));
+        assertThat(ex.getMessage(), equalTo("line 1:1: backquoted identifiers not supported; please use double quotes instead"));
     }
 
     public void testGreedyQuoting() {

+ 84 - 0
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.sql.parser;
 
+import com.google.common.base.Joiner;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.sql.expression.NamedExpression;
 import org.elasticsearch.xpack.sql.expression.Order;
@@ -22,6 +23,7 @@ import org.elasticsearch.xpack.sql.plan.logical.Project;
 import java.util.ArrayList;
 import java.util.List;
 
+import static java.util.Collections.nCopies;
 import static java.util.stream.Collectors.toList;
 import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.hasSize;
@@ -136,6 +138,88 @@ public class SqlParserTests extends ESTestCase {
         assertThat(mmqp.optionMap(), hasEntry("fuzzy_rewrite", "scoring_boolean"));
     }
 
+    public void testLimitToPreventStackOverflowFromLargeUnaryBooleanExpression() {
+        // Create expression in the form of NOT(NOT(NOT ... (b) ...)
+
+        // 40 elements is ok
+        new SqlParser().createExpression(
+            Joiner.on("").join(nCopies(40, "NOT(")).concat("b").concat(Joiner.on("").join(nCopies(40, ")"))));
+
+        // 100 elements parser's "circuit breaker" is triggered
+        ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression(
+            Joiner.on("").join(nCopies(100, "NOT(")).concat("b").concat(Joiner.on("").join(nCopies(100, ")")))));
+        assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
+    }
+
+    public void testLimitToPreventStackOverflowFromLargeBinaryBooleanExpression() {
+        // Create expression in the form of a = b OR a = b OR ... a = b
+
+        // 50 elements is ok
+        new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(50, "a = b")));
+
+        // 100 elements parser's "circuit breaker" is triggered
+        ParsingException e = expectThrows(ParsingException.class, () ->
+            new SqlParser().createExpression(Joiner.on(" OR ").join(nCopies(100, "a = b"))));
+        assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
+    }
+
+    public void testLimitToPreventStackOverflowFromLargeUnaryArithmeticExpression() {
+        // Create expression in the form of abs(abs(abs ... (i) ...)
+
+        // 50 elements is ok
+        new SqlParser().createExpression(
+            Joiner.on("").join(nCopies(50, "abs(")).concat("i").concat(Joiner.on("").join(nCopies(50, ")"))));
+
+        // 101 elements parser's "circuit breaker" is triggered
+        ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createExpression(
+            Joiner.on("").join(nCopies(101, "abs(")).concat("i").concat(Joiner.on("").join(nCopies(101, ")")))));
+        assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
+    }
+
+    public void testLimitToPreventStackOverflowFromLargeBinaryArithmeticExpression() {
+        // Create expression in the form of a + a + a + ... + a
+
+        // 100 elements is ok
+        new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(100, "a")));
+
+        // 101 elements parser's "circuit breaker" is triggered
+        ParsingException e = expectThrows(ParsingException.class, () ->
+            new SqlParser().createExpression(Joiner.on(" + ").join(nCopies(101, "a"))));
+        assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
+    }
+
+    public void testLimitToPreventStackOverflowFromLargeSubselectTree() {
+        // Test with queries in the form of `SELECT * FROM (SELECT * FROM (... t) ...)
+
+        // 100 elements is ok
+        new SqlParser().createStatement(
+            Joiner.on(" (").join(nCopies(100, "SELECT * FROM"))
+                .concat("t")
+                .concat(Joiner.on("").join(nCopies(99, ")"))));
+
+        // 101 elements parser's "circuit breaker" is triggered
+        ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement(
+            Joiner.on(" (").join(nCopies(101, "SELECT * FROM"))
+                .concat("t")
+                .concat(Joiner.on("").join(nCopies(100, ")")))));
+        assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
+    }
+
+    public void testLimitToPreventStackOverflowFromLargeComplexSubselectTree() {
+        // Test with queries in the form of `SELECT true OR true OR .. FROM (SELECT true OR true OR... FROM (... t) ...)
+
+        new SqlParser().createStatement(
+            Joiner.on(" (").join(nCopies(20, "SELECT ")).
+                concat(Joiner.on(" OR ").join(nCopies(50, "true"))).concat(" FROM")
+                .concat("t").concat(Joiner.on("").join(nCopies(19, ")"))));
+
+        ParsingException e = expectThrows(ParsingException.class, () -> new SqlParser().createStatement(
+            Joiner.on(" (").join(nCopies(20, "SELECT ")).
+                concat(Joiner.on(" OR ").join(nCopies(100, "true"))).concat(" FROM")
+                .concat("t").concat(Joiner.on("").join(nCopies(19, ")")))));
+        assertEquals("expression is too large to parse, (tree's depth exceeds 100)", e.getErrorMessage());
+    }
+
     private LogicalPlan parseStatement(String sql) {
         return new SqlParser().createStatement(sql);
     }