Browse Source

SQL: Add range checks to interval multiplication operation (#83478)

This adds checks on the multiplication operation on intervals (with integer).
Bogdan Pintea 3 years ago
parent
commit
7c6a2a60ba

+ 6 - 0
docs/changelog/83478.yaml

@@ -0,0 +1,6 @@
+pr: 83478
+summary: Add range checks to interval multiplication operation
+area: SQL
+type: bug
+issues:
+ - 83336

+ 10 - 8
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticOperation.java

@@ -24,6 +24,8 @@ import java.time.ZonedDateTime;
 import java.time.temporal.Temporal;
 import java.util.function.BiFunction;
 
+import static org.elasticsearch.xpack.ql.type.DataTypeConverter.safeToLong;
+
 public enum SqlBinaryArithmeticOperation implements BinaryArithmeticOperation {
 
     ADD((Object l, Object r) -> {
@@ -85,17 +87,17 @@ public enum SqlBinaryArithmeticOperation implements BinaryArithmeticOperation {
         if (l instanceof Number && r instanceof Number) {
             return Arithmetics.mul((Number) l, (Number) r);
         }
-        if (l instanceof Number && r instanceof IntervalYearMonth) {
-            return ((IntervalYearMonth) r).mul(((Number) l).intValue());
+        if (l instanceof Number number && r instanceof IntervalYearMonth) {
+            return ((IntervalYearMonth) r).mul(safeToLong(number));
         }
-        if (r instanceof Number && l instanceof IntervalYearMonth) {
-            return ((IntervalYearMonth) l).mul(((Number) r).intValue());
+        if (r instanceof Number number && l instanceof IntervalYearMonth) {
+            return ((IntervalYearMonth) l).mul(safeToLong(number));
         }
-        if (l instanceof Number && r instanceof IntervalDayTime) {
-            return ((IntervalDayTime) r).mul(((Number) l).longValue());
+        if (l instanceof Number number && r instanceof IntervalDayTime) {
+            return ((IntervalDayTime) r).mul(safeToLong(number));
         }
-        if (r instanceof Number && l instanceof IntervalDayTime) {
-            return ((IntervalDayTime) l).mul(((Number) r).longValue());
+        if (r instanceof Number number && l instanceof IntervalDayTime) {
+            return ((IntervalDayTime) l).mul(safeToLong(number));
         }
 
         throw new QlIllegalArgumentException(

+ 31 - 0
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/SqlBinaryArithmeticTests.java

@@ -24,6 +24,7 @@ import java.time.temporal.TemporalAmount;
 
 import static org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic.Arithmetics.mod;
 import static org.elasticsearch.xpack.ql.tree.Source.EMPTY;
+import static org.elasticsearch.xpack.ql.util.NumericUtils.UNSIGNED_LONG_MAX;
 import static org.elasticsearch.xpack.sql.type.SqlDataTypes.INTERVAL_DAY;
 import static org.elasticsearch.xpack.sql.type.SqlDataTypes.INTERVAL_DAY_TO_HOUR;
 import static org.elasticsearch.xpack.sql.type.SqlDataTypes.INTERVAL_HOUR;
@@ -243,6 +244,36 @@ public class SqlBinaryArithmeticTests extends ESTestCase {
         assertEquals(INTERVAL_MONTH, result.dataType());
     }
 
+    public void testMulIntegerIntervalYearMonthOverflow() {
+        Literal l = interval(Period.ofYears(1).plusMonths(11), INTERVAL_YEAR);
+        ArithmeticException expect = expectThrows(ArithmeticException.class, () -> mul(l, L(Integer.MAX_VALUE)));
+        assertEquals("integer overflow", expect.getMessage());
+    }
+
+    public void testMulLongIntervalYearMonthOverflow() {
+        Literal l = interval(Period.ofYears(1), INTERVAL_YEAR);
+        QlIllegalArgumentException expect = expectThrows(QlIllegalArgumentException.class, () -> mul(l, L(Long.MAX_VALUE)));
+        assertEquals("[9223372036854775807] out of [integer] range", expect.getMessage());
+    }
+
+    public void testMulUnsignedLongIntervalYearMonthOverflow() {
+        Literal l = interval(Period.ofYears(1), INTERVAL_YEAR);
+        QlIllegalArgumentException expect = expectThrows(QlIllegalArgumentException.class, () -> mul(l, L(UNSIGNED_LONG_MAX)));
+        assertEquals("[18446744073709551615] out of [long] range", expect.getMessage());
+    }
+
+    public void testMulLongIntervalDayTimeOverflow() {
+        Literal l = interval(Duration.ofDays(1), INTERVAL_DAY);
+        ArithmeticException expect = expectThrows(ArithmeticException.class, () -> mul(l, L(Long.MAX_VALUE)));
+        assertEquals("Exceeds capacity of Duration: 796899343984252629724800000000000", expect.getMessage());
+    }
+
+    public void testMulUnsignedLongIntervalDayTimeOverflow() {
+        Literal l = interval(Duration.ofDays(1), INTERVAL_DAY);
+        QlIllegalArgumentException expect = expectThrows(QlIllegalArgumentException.class, () -> mul(l, L(UNSIGNED_LONG_MAX)));
+        assertEquals("[18446744073709551615] out of [long] range", expect.getMessage());
+    }
+
     public void testAddNullInterval() {
         Literal literal = interval(Period.ofMonths(1), INTERVAL_MONTH);
         Add result = new Add(EMPTY, L(null), literal);