فهرست منبع

SQL: Fix MOD() for long and integer arguments (#36599)

Previously, Math.floorMod was used for integers and longs
which has different logic for negative numbers. Also, the
priority of data types check was wrong as if one of the args
is double the evaluation should be with doubles, then for floats,
then longs and finally integers.

Fixes: #36364
Marios Trivyzas 6 سال پیش
والد
کامیت
730b154c93

+ 5 - 5
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/Arithmetics.java

@@ -124,17 +124,17 @@ public abstract class Arithmetics {
             return null;
             return null;
         }
         }
 
 
-        if (l instanceof Long || r instanceof Long) {
-            return Long.valueOf(Math.floorMod(l.longValue(), r.longValue()));
-        }
         if (l instanceof Double || r instanceof Double) {
         if (l instanceof Double || r instanceof Double) {
             return Double.valueOf(l.doubleValue() % r.doubleValue());
             return Double.valueOf(l.doubleValue() % r.doubleValue());
         }
         }
         if (l instanceof Float || r instanceof Float) {
         if (l instanceof Float || r instanceof Float) {
             return Float.valueOf(l.floatValue() % r.floatValue());
             return Float.valueOf(l.floatValue() % r.floatValue());
         }
         }
+        if (l instanceof Long || r instanceof Long) {
+            return Long.valueOf(l.longValue() % r.longValue());
+        }
 
 
-        return Math.floorMod(l.intValue(), r.intValue());
+        return l.intValue() % r.intValue();
     }
     }
 
 
     static Number negate(Number n) {
     static Number negate(Number n) {
@@ -162,4 +162,4 @@ public abstract class Arithmetics {
 
 
         return Integer.valueOf(Math.negateExact(n.intValue()));
         return Integer.valueOf(Math.negateExact(n.intValue()));
     }
     }
-}
+}

+ 38 - 15
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/expression/predicate/operator/arithmetic/BinaryArithmeticTests.java

@@ -19,6 +19,7 @@ import java.time.Period;
 import java.time.ZonedDateTime;
 import java.time.ZonedDateTime;
 import java.time.temporal.TemporalAmount;
 import java.time.temporal.TemporalAmount;
 
 
+import static org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.Arithmetics.mod;
 import static org.elasticsearch.xpack.sql.tree.Location.EMPTY;
 import static org.elasticsearch.xpack.sql.tree.Location.EMPTY;
 import static org.elasticsearch.xpack.sql.type.DataType.INTERVAL_DAY;
 import static org.elasticsearch.xpack.sql.type.DataType.INTERVAL_DAY;
 import static org.elasticsearch.xpack.sql.type.DataType.INTERVAL_DAY_TO_HOUR;
 import static org.elasticsearch.xpack.sql.type.DataType.INTERVAL_DAY_TO_HOUR;
@@ -29,32 +30,54 @@ import static org.elasticsearch.xpack.sql.type.DataType.INTERVAL_YEAR_TO_MONTH;
 
 
 public class BinaryArithmeticTests extends ESTestCase {
 public class BinaryArithmeticTests extends ESTestCase {
 
 
-    public void testAddNumbers() throws Exception {
+    public void testAddNumbers() {
         assertEquals(Long.valueOf(3), add(1L, 2L));
         assertEquals(Long.valueOf(3), add(1L, 2L));
     }
     }
 
 
-    public void testAddYearMonthIntervals() throws Exception {
+    public void testMod() {
+        assertEquals(2, mod(10, 8));
+        assertEquals(2, mod(10, -8));
+        assertEquals(-2, mod(-10, 8));
+        assertEquals(-2, mod(-10, -8));
+
+        assertEquals(2L, mod(10L, 8));
+        assertEquals(2L, mod(10, -8L));
+        assertEquals(-2L, mod(-10L, 8L));
+        assertEquals(-2L, mod(-10L, -8L));
+
+        assertEquals(2.3000002f, mod(10.3f, 8L));
+        assertEquals(1.5f, mod(10, -8.5f));
+        assertEquals(-1.8000002f, mod(-10.3f, 8.5f));
+        assertEquals(-1.8000002f, mod(-10.3f, -8.5f));
+
+        assertEquals(2.3000000000000007d, mod(10.3d, 8L));
+        assertEquals(1.5d, mod(10, -8.5d));
+        assertEquals(-1.8000001907348633d, mod(-10.3f, 8.5d));
+        assertEquals(-1.8000000000000007, mod(-10.3d, -8.5d));
+    }
+
+    public void testAddYearMonthIntervals() {
         Literal l = interval(Period.ofYears(1), INTERVAL_YEAR);
         Literal l = interval(Period.ofYears(1), INTERVAL_YEAR);
         Literal r = interval(Period.ofMonths(2), INTERVAL_MONTH);
         Literal r = interval(Period.ofMonths(2), INTERVAL_MONTH);
         IntervalYearMonth x = add(l, r);
         IntervalYearMonth x = add(l, r);
         assertEquals(interval(Period.ofYears(1).plusMonths(2), INTERVAL_YEAR_TO_MONTH), L(x));
         assertEquals(interval(Period.ofYears(1).plusMonths(2), INTERVAL_YEAR_TO_MONTH), L(x));
     }
     }
 
 
-    public void testAddYearMonthMixedIntervals() throws Exception {
+    public void testAddYearMonthMixedIntervals() {
         Literal l = interval(Period.ofYears(1).plusMonths(5), INTERVAL_YEAR_TO_MONTH);
         Literal l = interval(Period.ofYears(1).plusMonths(5), INTERVAL_YEAR_TO_MONTH);
         Literal r = interval(Period.ofMonths(2), INTERVAL_MONTH);
         Literal r = interval(Period.ofMonths(2), INTERVAL_MONTH);
         IntervalYearMonth x = add(l, r);
         IntervalYearMonth x = add(l, r);
         assertEquals(interval(Period.ofYears(1).plusMonths(7), INTERVAL_YEAR_TO_MONTH), L(x));
         assertEquals(interval(Period.ofYears(1).plusMonths(7), INTERVAL_YEAR_TO_MONTH), L(x));
     }
     }
 
 
-    public void testAddDayTimeIntervals() throws Exception {
+    public void testAddDayTimeIntervals() {
         Literal l = interval(Duration.ofDays(1), INTERVAL_DAY);
         Literal l = interval(Duration.ofDays(1), INTERVAL_DAY);
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         IntervalDayTime x = add(l, r);
         IntervalDayTime x = add(l, r);
         assertEquals(interval(Duration.ofDays(1).plusHours(2), INTERVAL_DAY_TO_HOUR), L(x));
         assertEquals(interval(Duration.ofDays(1).plusHours(2), INTERVAL_DAY_TO_HOUR), L(x));
     }
     }
 
 
-    public void testAddYearMonthIntervalToDate() throws Exception {
+    public void testAddYearMonthIntervalToDate() {
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         Literal l = L(now);
         Literal l = L(now);
         TemporalAmount t = Period.ofYears(100).plusMonths(50);
         TemporalAmount t = Period.ofYears(100).plusMonths(50);
@@ -63,7 +86,7 @@ public class BinaryArithmeticTests extends ESTestCase {
         assertEquals(L(now.plus(t)), L(x));
         assertEquals(L(now.plus(t)), L(x));
     }
     }
 
 
-    public void testAddDayTimeIntervalToDate() throws Exception {
+    public void testAddDayTimeIntervalToDate() {
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         Literal l = L(now);
         Literal l = L(now);
         TemporalAmount t = Duration.ofHours(2);
         TemporalAmount t = Duration.ofHours(2);
@@ -72,7 +95,7 @@ public class BinaryArithmeticTests extends ESTestCase {
         assertEquals(L(now.plus(t)), L(x));
         assertEquals(L(now.plus(t)), L(x));
     }
     }
 
 
-    public void testAddDayTimeIntervalToDateReverse() throws Exception {
+    public void testAddDayTimeIntervalToDateReverse() {
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         Literal l = L(now);
         Literal l = L(now);
         TemporalAmount t = Duration.ofHours(2);
         TemporalAmount t = Duration.ofHours(2);
@@ -81,27 +104,27 @@ public class BinaryArithmeticTests extends ESTestCase {
         assertEquals(L(now.plus(t)), L(x));
         assertEquals(L(now.plus(t)), L(x));
     }
     }
 
 
-    public void testAddNumberToIntervalIllegal() throws Exception {
+    public void testAddNumberToIntervalIllegal() {
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         SqlIllegalArgumentException expect = expectThrows(SqlIllegalArgumentException.class, () -> add(r, L(1)));
         SqlIllegalArgumentException expect = expectThrows(SqlIllegalArgumentException.class, () -> add(r, L(1)));
         assertEquals("Cannot compute [+] between [IntervalDayTime] [Integer]", expect.getMessage());
         assertEquals("Cannot compute [+] between [IntervalDayTime] [Integer]", expect.getMessage());
     }
     }
 
 
-    public void testSubYearMonthIntervals() throws Exception {
+    public void testSubYearMonthIntervals() {
         Literal l = interval(Period.ofYears(1), INTERVAL_YEAR);
         Literal l = interval(Period.ofYears(1), INTERVAL_YEAR);
         Literal r = interval(Period.ofMonths(2), INTERVAL_MONTH);
         Literal r = interval(Period.ofMonths(2), INTERVAL_MONTH);
         IntervalYearMonth x = sub(l, r);
         IntervalYearMonth x = sub(l, r);
         assertEquals(interval(Period.ofMonths(10), INTERVAL_YEAR_TO_MONTH), L(x));
         assertEquals(interval(Period.ofMonths(10), INTERVAL_YEAR_TO_MONTH), L(x));
     }
     }
 
 
-    public void testSubDayTimeIntervals() throws Exception {
+    public void testSubDayTimeIntervals() {
         Literal l = interval(Duration.ofDays(1).plusHours(10), INTERVAL_DAY_TO_HOUR);
         Literal l = interval(Duration.ofDays(1).plusHours(10), INTERVAL_DAY_TO_HOUR);
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         IntervalDayTime x = sub(l, r);
         IntervalDayTime x = sub(l, r);
         assertEquals(interval(Duration.ofDays(1).plusHours(8), INTERVAL_DAY_TO_HOUR), L(x));
         assertEquals(interval(Duration.ofDays(1).plusHours(8), INTERVAL_DAY_TO_HOUR), L(x));
     }
     }
 
 
-    public void testSubYearMonthIntervalToDate() throws Exception {
+    public void testSubYearMonthIntervalToDate() {
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         Literal l = L(now);
         Literal l = L(now);
         TemporalAmount t = Period.ofYears(100).plusMonths(50);
         TemporalAmount t = Period.ofYears(100).plusMonths(50);
@@ -110,7 +133,7 @@ public class BinaryArithmeticTests extends ESTestCase {
         assertEquals(L(now.minus(t)), L(x));
         assertEquals(L(now.minus(t)), L(x));
     }
     }
 
 
-    public void testSubYearMonthIntervalToDateIllegal() throws Exception {
+    public void testSubYearMonthIntervalToDateIllegal() {
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         Literal l = L(now);
         Literal l = L(now);
         TemporalAmount t = Period.ofYears(100).plusMonths(50);
         TemporalAmount t = Period.ofYears(100).plusMonths(50);
@@ -119,13 +142,13 @@ public class BinaryArithmeticTests extends ESTestCase {
         assertEquals("Cannot substract a date from an interval; do you mean the reverse?", ex.getMessage());
         assertEquals("Cannot substract a date from an interval; do you mean the reverse?", ex.getMessage());
     }
     }
 
 
-    public void testSubNumberFromIntervalIllegal() throws Exception {
+    public void testSubNumberFromIntervalIllegal() {
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         Literal r = interval(Duration.ofHours(2), INTERVAL_HOUR);
         SqlIllegalArgumentException expect = expectThrows(SqlIllegalArgumentException.class, () -> sub(r, L(1)));
         SqlIllegalArgumentException expect = expectThrows(SqlIllegalArgumentException.class, () -> sub(r, L(1)));
         assertEquals("Cannot compute [-] between [IntervalDayTime] [Integer]", expect.getMessage());
         assertEquals("Cannot compute [-] between [IntervalDayTime] [Integer]", expect.getMessage());
     }
     }
 
 
-    public void testSubDayTimeIntervalToDate() throws Exception {
+    public void testSubDayTimeIntervalToDate() {
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         ZonedDateTime now = ZonedDateTime.now(DateUtils.UTC);
         Literal l = L(now);
         Literal l = L(now);
         TemporalAmount t = Duration.ofHours(2);
         TemporalAmount t = Duration.ofHours(2);
@@ -158,4 +181,4 @@ public class BinaryArithmeticTests extends ESTestCase {
                  : new IntervalDayTime((Duration) value, intervalType);
                  : new IntervalDayTime((Duration) value, intervalType);
         return Literal.of(EMPTY, i);
         return Literal.of(EMPTY, i);
     }
     }
-}
+}