|
@@ -5,6 +5,7 @@
|
|
*/
|
|
*/
|
|
package org.elasticsearch.xpack.sql.parser;
|
|
package org.elasticsearch.xpack.sql.parser;
|
|
|
|
|
|
|
|
+import com.carrotsearch.hppc.ObjectShortHashMap;
|
|
import org.antlr.v4.runtime.BaseErrorListener;
|
|
import org.antlr.v4.runtime.BaseErrorListener;
|
|
import org.antlr.v4.runtime.CharStream;
|
|
import org.antlr.v4.runtime.CharStream;
|
|
import org.antlr.v4.runtime.CommonToken;
|
|
import org.antlr.v4.runtime.CommonToken;
|
|
@@ -22,8 +23,8 @@ import org.antlr.v4.runtime.atn.PredictionMode;
|
|
import org.antlr.v4.runtime.dfa.DFA;
|
|
import org.antlr.v4.runtime.dfa.DFA;
|
|
import org.antlr.v4.runtime.misc.Pair;
|
|
import org.antlr.v4.runtime.misc.Pair;
|
|
import org.antlr.v4.runtime.tree.TerminalNode;
|
|
import org.antlr.v4.runtime.tree.TerminalNode;
|
|
|
|
+import org.apache.logging.log4j.LogManager;
|
|
import org.apache.logging.log4j.Logger;
|
|
import org.apache.logging.log4j.Logger;
|
|
-import org.elasticsearch.common.logging.Loggers;
|
|
|
|
import org.elasticsearch.xpack.sql.expression.Expression;
|
|
import org.elasticsearch.xpack.sql.expression.Expression;
|
|
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
|
|
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
|
|
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
|
|
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
|
|
@@ -41,7 +42,8 @@ import java.util.function.Function;
|
|
import static java.lang.String.format;
|
|
import static java.lang.String.format;
|
|
|
|
|
|
public class SqlParser {
|
|
public class SqlParser {
|
|
- private static final Logger log = Loggers.getLogger(SqlParser.class);
|
|
|
|
|
|
+
|
|
|
|
+ private static final Logger log = LogManager.getLogger();
|
|
|
|
|
|
private final boolean DEBUG = false;
|
|
private final boolean DEBUG = false;
|
|
|
|
|
|
@@ -83,7 +85,9 @@ public class SqlParser {
|
|
return invokeParser(expression, params, SqlBaseParser::singleExpression, AstBuilder::expression);
|
|
return invokeParser(expression, params, SqlBaseParser::singleExpression, AstBuilder::expression);
|
|
}
|
|
}
|
|
|
|
|
|
- private <T> T invokeParser(String sql, List<SqlTypedParamValue> params, Function<SqlBaseParser, ParserRuleContext> parseFunction,
|
|
|
|
|
|
+ private <T> T invokeParser(String sql,
|
|
|
|
+ List<SqlTypedParamValue> params, Function<SqlBaseParser,
|
|
|
|
+ ParserRuleContext> parseFunction,
|
|
BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
|
|
BiFunction<AstBuilder, ParserRuleContext, T> visitor) {
|
|
SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
|
|
SqlBaseLexer lexer = new SqlBaseLexer(new CaseInsensitiveStream(sql));
|
|
|
|
|
|
@@ -96,6 +100,7 @@ public class SqlParser {
|
|
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
|
|
CommonTokenStream tokenStream = new CommonTokenStream(tokenSource);
|
|
SqlBaseParser parser = new SqlBaseParser(tokenStream);
|
|
SqlBaseParser parser = new SqlBaseParser(tokenStream);
|
|
|
|
|
|
|
|
+ parser.addParseListener(new CircuitBreakerListener());
|
|
parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
|
|
parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames())));
|
|
|
|
|
|
parser.removeErrorListeners();
|
|
parser.removeErrorListeners();
|
|
@@ -125,7 +130,7 @@ public class SqlParser {
|
|
return visitor.apply(new AstBuilder(paramTokens), tree);
|
|
return visitor.apply(new AstBuilder(paramTokens), tree);
|
|
}
|
|
}
|
|
|
|
|
|
- private void debug(SqlBaseParser parser) {
|
|
|
|
|
|
+ private static void debug(SqlBaseParser parser) {
|
|
|
|
|
|
// when debugging, use the exact prediction mode (needed for diagnostics as well)
|
|
// when debugging, use the exact prediction mode (needed for diagnostics as well)
|
|
parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
|
|
parser.getInterpreter().setPredictionMode(PredictionMode.LL_EXACT_AMBIG_DETECTION);
|
|
@@ -154,7 +159,7 @@ public class SqlParser {
|
|
public void exitBackQuotedIdentifier(SqlBaseParser.BackQuotedIdentifierContext context) {
|
|
public void exitBackQuotedIdentifier(SqlBaseParser.BackQuotedIdentifierContext context) {
|
|
Token token = context.BACKQUOTED_IDENTIFIER().getSymbol();
|
|
Token token = context.BACKQUOTED_IDENTIFIER().getSymbol();
|
|
throw new ParsingException(
|
|
throw new ParsingException(
|
|
- "backquoted indetifiers not supported; please use double quotes instead",
|
|
|
|
|
|
+ "backquoted identifiers not supported; please use double quotes instead",
|
|
null,
|
|
null,
|
|
token.getLine(),
|
|
token.getLine(),
|
|
token.getCharPositionInLine());
|
|
token.getCharPositionInLine());
|
|
@@ -205,6 +210,35 @@ public class SqlParser {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
+ * Used to catch large expressions that can lead to stack overflows
|
|
|
|
+ */
|
|
|
|
+ private class CircuitBreakerListener extends SqlBaseBaseListener {
|
|
|
|
+
|
|
|
|
+ private static final short MAX_RULE_DEPTH = 100;
|
|
|
|
+
|
|
|
|
+ // Keep current depth for every rule visited.
|
|
|
|
+ // The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ...
|
|
|
|
+ // are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing
|
|
|
|
+ // while the stack call depth is, leading to a StackOverflowError.
|
|
|
|
+ private ObjectShortHashMap<String> depthCounts = new ObjectShortHashMap<>();
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void enterEveryRule(ParserRuleContext ctx) {
|
|
|
|
+ short currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
|
|
|
|
+ if (currentDepth > MAX_RULE_DEPTH) {
|
|
|
|
+ throw new ParsingException("expression is too large to parse, (tree's depth exceeds {})", MAX_RULE_DEPTH);
|
|
|
|
+ }
|
|
|
|
+ super.enterEveryRule(ctx);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void exitEveryRule(ParserRuleContext ctx) {
|
|
|
|
+ depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1);
|
|
|
|
+ super.exitEveryRule(ctx);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
|
|
private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
|
|
@Override
|
|
@Override
|
|
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,
|
|
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line,
|