Pārlūkot izejas kodu

Fix Locate function optional parameter handling (#49666)

Andrei Stefan 5 gadi atpakaļ
vecāks
revīzija
dd3aeb8f54

+ 20 - 0
x-pack/plugin/sql/qa/src/main/resources/functions.csv-spec

@@ -333,6 +333,26 @@ SELECT LOCATE('a',"first_name") pos, INSERT("first_name",LOCATE('a',"first_name"
 8              |ChirstiAAAn  
 ;
 
+selectLocateWithConditional1
+SELECT CAST(LOCATE('a', CASE WHEN TRUNCATE(salary, 3) > 40000 THEN first_name.keyword ELSE last_name.keyword END) > 0 AS STRING) AS x, COUNT(*) AS c FROM test_emp GROUP BY x ORDER BY c ASC;
+
+      x:s      |       c:l       
+---------------+---------------
+null           |4              
+false          |43             
+true           |53             
+;
+
+selectLocateWithConditional2
+SELECT CAST(LOCATE(CASE WHEN languages > 3 THEN 'a' ELSE 'b' END, CASE WHEN TRUNCATE(salary, 3) > 40000 THEN first_name.keyword ELSE last_name.keyword END, CASE WHEN gender IS NOT NULL THEN 3 ELSE NULL END) > 0 AS STRING) AS x, COUNT(*) AS c FROM test_emp GROUP BY x ORDER BY c DESC;
+
+      x:s      |       c:l       
+---------------+---------------
+false          |80             
+true           |16             
+null           |4              
+;
+
 selectLeft
 SELECT LEFT("first_name",2) f FROM "test_emp" ORDER BY "first_name" LIMIT 10;
 

+ 3 - 0
x-pack/plugin/sql/qa/src/main/resources/string-functions.sql-spec

@@ -253,6 +253,9 @@ SELECT LOCATE('a',"first_name",7) pos, INSERT("first_name",LOCATE('a',"first_nam
 selectLocateAndInsertWithLocateWithConditionAndTwoParameters
 SELECT LOCATE('a',"first_name") pos, INSERT("first_name",LOCATE('a',"first_name"),1,'AAA') inserted FROM "test_emp" WHERE LOCATE('a',"first_name") > 0 ORDER BY "first_name" LIMIT 10;
 
+selectLocateWithConditional
+SELECT LOCATE(CASE WHEN FALSE THEN NULL ELSE 'x' END, "first_name") > 0 AS x, COUNT(*) AS c FROM "test_emp" GROUP BY x ORDER BY c ASC;
+
 selectLeft
 SELECT LEFT("first_name",2) f FROM "test_emp" ORDER BY "first_name" NULLS LAST LIMIT 10;
 

+ 4 - 2
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/string/Locate.java

@@ -133,10 +133,12 @@ public class Locate extends ScalarFunction {
 
     @Override
     public Expression replaceChildren(List<Expression> newChildren) {
-        if (newChildren.size() != 3) {
+        if (start != null && newChildren.size() != 3) {
             throw new IllegalArgumentException("expected [3] children but received [" + newChildren.size() + "]");
+        } else if (start == null && newChildren.size() != 2) {
+            throw new IllegalArgumentException("expected [2] children but received [" + newChildren.size() + "]");
         }
 
-        return new Locate(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2));
+        return new Locate(source(), newChildren.get(0), newChildren.get(1), start == null ? null : newChildren.get(2));
     }
 }

+ 5 - 6
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/function/scalar/string/LocateFunctionPipeTests.java

@@ -75,24 +75,23 @@ public class LocateFunctionPipeTests extends AbstractNodeTestCase<LocateFunction
         LocateFunctionPipe b = randomInstance();
         Pipe newPattern = pipe(((Expression) randomValueOtherThan(b.pattern(), () -> randomStringLiteral())));
         Pipe newSource = pipe(((Expression) randomValueOtherThan(b.source(), () -> randomStringLiteral())));
-        Pipe newStart;
+        Pipe newStart = b.start() == null ? null : pipe(((Expression) randomValueOtherThan(b.start(), () -> randomIntLiteral())));
         
-        LocateFunctionPipe newB = new LocateFunctionPipe(
-                b.source(), b.expression(), b.pattern(), b.src(), b.start());
-        newStart = pipe(((Expression) randomValueOtherThan(b.start(), () -> randomIntLiteral())));
+        LocateFunctionPipe newB = new LocateFunctionPipe(b.source(), b.expression(), b.pattern(), b.src(), b.start());
         LocateFunctionPipe transformed = null;
         
         // generate all the combinations of possible children modifications and test all of them
         for(int i = 1; i < 4; i++) {
             for(BitSet comb : new Combinations(3, i)) {
+                Pipe tempNewStart = b.start() == null ? b.start() : (comb.get(2) ? newStart : b.start());
                 transformed = (LocateFunctionPipe) newB.replaceChildren(
                         comb.get(0) ? newPattern : b.pattern(),
                         comb.get(1) ? newSource : b.src(),
-                        comb.get(2) ? newStart : b.start());
+                        tempNewStart);
                 
                 assertEquals(transformed.pattern(), comb.get(0) ? newPattern : b.pattern());
                 assertEquals(transformed.src(), comb.get(1) ? newSource : b.src());
-                assertEquals(transformed.start(), comb.get(2) ? newStart : b.start());
+                assertEquals(transformed.start(), tempNewStart);
                 assertEquals(transformed.expression(), b.expression());
                 assertEquals(transformed.source(), b.source());
             }