Browse Source

ES|QL: Add number of max branches for FORK (#129834)

Ioana Tagirta 3 months ago
parent
commit
11ca4f688a

+ 1 - 1
x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/ForkGenerator.java

@@ -34,7 +34,7 @@ public class ForkGenerator implements CommandGenerator {
             }
         }
 
-        int n = randomIntBetween(2, 10);
+        int n = randomIntBetween(2, 8);
 
         String cmd = " | FORK " + "( WHERE true ) ".repeat(n) + " | WHERE _fork == \"fork" + randomIntBetween(1, n) + "\" | DROP _fork";
 

+ 12 - 1
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ForkIT.java

@@ -1007,7 +1007,7 @@ public class ForkIT extends AbstractEsqlIntegTestCase {
                ( WHERE content:"fox" )
             """;
         var e = expectThrows(ParsingException.class, () -> run(query));
-        assertTrue(e.getMessage().contains("Fork requires at least two branches"));
+        assertTrue(e.getMessage().contains("Fork requires at least 2 branches"));
     }
 
     public void testForkWithinFork() {
@@ -1047,6 +1047,17 @@ public class ForkIT extends AbstractEsqlIntegTestCase {
         }
     }
 
+    public void testWithTooManySubqueries() {
+        var query = """
+            FROM test
+            | FORK (WHERE true) (WHERE true) (WHERE true) (WHERE true) (WHERE true)
+                   (WHERE true) (WHERE true) (WHERE true) (WHERE true)
+            """;
+        var e = expectThrows(ParsingException.class, () -> run(query));
+        assertTrue(e.getMessage().contains("Fork requires less than 8 branches"));
+
+    }
+
     private void createAndPopulateIndices() {
         var indexName = "test";
         var client = client().admin().indices();

+ 6 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

@@ -651,9 +651,13 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
     @SuppressWarnings("unchecked")
     public PlanFactory visitForkCommand(EsqlBaseParser.ForkCommandContext ctx) {
         List<PlanFactory> subQueries = visitForkSubQueries(ctx.forkSubQueries());
-        if (subQueries.size() < 2) {
-            throw new ParsingException(source(ctx), "Fork requires at least two branches");
+        if (subQueries.size() < Fork.MIN_BRANCHES) {
+            throw new ParsingException(source(ctx), "Fork requires at least " + Fork.MIN_BRANCHES + " branches");
         }
+        if (subQueries.size() > Fork.MAX_BRANCHES) {
+            throw new ParsingException(source(ctx), "Fork requires less than " + Fork.MAX_BRANCHES + " branches");
+        }
+
         return input -> {
             checkForRemoteClusters(input, source(ctx), "FORK");
             List<LogicalPlan> subPlans = subQueries.stream().map(planFactory -> planFactory.apply(input)).toList();

+ 8 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Fork.java

@@ -36,13 +36,19 @@ import java.util.stream.Collectors;
 public class Fork extends LogicalPlan implements PostAnalysisPlanVerificationAware, TelemetryAware {
 
     public static final String FORK_FIELD = "_fork";
+    public static final int MAX_BRANCHES = 8;
+    public static final int MIN_BRANCHES = 2;
     private final List<Attribute> output;
 
     public Fork(Source source, List<LogicalPlan> children, List<Attribute> output) {
         super(source, children);
-        if (children.size() < 2) {
-            throw new IllegalArgumentException("requires more than two subqueries, got:" + children.size());
+        if (children.size() < MIN_BRANCHES) {
+            throw new IllegalArgumentException("FORK requires more than " + MIN_BRANCHES + " branches, got: " + children.size());
         }
+        if (children.size() > MAX_BRANCHES) {
+            throw new IllegalArgumentException("FORK requires less than " + MAX_BRANCHES + " subqueries, got: " + children.size());
+        }
+
         this.output = output;
     }
 

+ 40 - 10
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java

@@ -3354,7 +3354,15 @@ public class StatementParserTests extends AbstractStatementParserTests {
                ( EVAL xyz = ( (a/b) * (b/a)) )
                ( WHERE a < 1 )
                ( KEEP a )
-               ( DROP b )
+            | KEEP a
+            """;
+
+        var plan = statement(query);
+        assertThat(plan, instanceOf(Keep.class));
+
+        query = """
+            FROM foo*
+            | FORK
                ( RENAME a as c )
                ( MV_EXPAND a )
                ( CHANGE_POINT a on b )
@@ -3365,7 +3373,7 @@ public class StatementParserTests extends AbstractStatementParserTests {
             | KEEP a
             """;
 
-        var plan = statement(query);
+        plan = statement(query);
         assertThat(plan, instanceOf(Keep.class));
     }
 
@@ -3383,7 +3391,15 @@ public class StatementParserTests extends AbstractStatementParserTests {
                ( EVAL xyz = ( (a/b) * (b/a)) )
                ( WHERE a < 1 )
                ( KEEP a )
-               ( DROP b )
+
+            | KEEP a
+            """;
+        var plan = statement(query);
+        assertThat(plan, instanceOf(Keep.class));
+
+        query = """
+            FROM foo*
+            | FORK
                ( RENAME a as c )
                ( MV_EXPAND a )
                ( CHANGE_POINT a on b )
@@ -3392,22 +3408,36 @@ public class StatementParserTests extends AbstractStatementParserTests {
                ( FORK ( WHERE a:"baz" ) ( EVAL x = [ 1, 2, 3 ] ) )
                ( COMPLETION a = b WITH c )
                ( SAMPLE 0.99 )
+            | KEEP a
+            """;
+        plan = statement(query);
+        assertThat(plan, instanceOf(Keep.class));
+
+        query = """
+            FROM foo*
+            | FORK
                ( INLINESTATS x = MIN(a), y = MAX(b) WHERE d > 1000 )
                ( INSIST_🐔 a )
                ( LOOKUP_🐔 a on b )
             | KEEP a
             """;
-
-        var plan = statement(query);
+        plan = statement(query);
         assertThat(plan, instanceOf(Keep.class));
     }
 
     public void testInvalidFork() {
-        expectError("FROM foo* | FORK (WHERE a:\"baz\")", "line 1:13: Fork requires at least two branches");
-        expectError("FROM foo* | FORK (LIMIT 10)", "line 1:13: Fork requires at least two branches");
-        expectError("FROM foo* | FORK (SORT a)", "line 1:13: Fork requires at least two branches");
-        expectError("FROM foo* | FORK (WHERE x>1 | LIMIT 5)", "line 1:13: Fork requires at least two branches");
-        expectError("FROM foo* | WHERE x>1 | FORK (WHERE a:\"baz\")", "Fork requires at least two branches");
+        expectError("FROM foo* | FORK (WHERE a:\"baz\")", "line 1:13: Fork requires at least 2 branches");
+        expectError("FROM foo* | FORK (LIMIT 10)", "line 1:13: Fork requires at least 2 branches");
+        expectError("FROM foo* | FORK (SORT a)", "line 1:13: Fork requires at least 2 branches");
+        expectError("FROM foo* | FORK (WHERE x>1 | LIMIT 5)", "line 1:13: Fork requires at least 2 branches");
+        expectError("FROM foo* | WHERE x>1 | FORK (WHERE a:\"baz\")", "Fork requires at least 2 branches");
+
+        expectError("""
+            FROM foo*
+            | FORK (where true) (where true) (where true) (where true)
+                   (where true) (where true) (where true) (where true)
+                   (where true)
+            """, "Fork requires less than 8 branches");
 
         expectError("FROM foo* | FORK ( x+1 ) ( WHERE y>2 )", "line 1:20: mismatched input 'x+1'");
         expectError("FROM foo* | FORK ( LIMIT 10 ) ( y+2 )", "line 1:33: mismatched input 'y+2'");

+ 1 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java

@@ -562,7 +562,7 @@ public class EsqlNodeSubclassTests<T extends B, B extends Node<B>> extends NodeS
 
     private static int randomSizeForCollection(Class<? extends Node<?>> toBuildClass) {
         int minCollectionLength = 0;
-        int maxCollectionLength = 10;
+        int maxCollectionLength = 8;
 
         if (hasAtLeastTwoChildren(toBuildClass)) {
             minCollectionLength = 2;