transcendental.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import math, functools
  2. from typing import Tuple, List
  3. from tinygrad.dtype import dtypes, DType
  4. from tinygrad.codegen.uops import UOp
  5. TRANSCENDENTAL_SUPPORTED_DTYPES = {dtypes.float16, dtypes.float32, dtypes.float64}
  6. def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
  7. """replace inf -> inf, -inf -> _inf, nan -> nan, otherwise -> ratio"""
  8. return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
  9. # *** helper functions for double/quad precision arithmetics ***
  10. def dfadd2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx + yx, xy + yy
  11. def dfmul2_f2_f2_f2(xx:UOp, xy:UOp, yx:UOp, yy:UOp) -> Tuple[UOp, UOp]: return xx * yx, xx * yy + xy * yx
  12. def dfdiv2_f2_f2_f2(nx:UOp, ny:UOp, dx:UOp, dy:UOp) -> Tuple[UOp, UOp]:
  13. t = dx.recip()
  14. qx = nx * t
  15. qy = (ny - qx * dy) * t
  16. return qx, qy
  17. # *** helper functions for bit manipulation ***
  18. def significand_bits(d:DType) -> int: return {dtypes.float64: 52, dtypes.float32: 23, dtypes.float16: 10}[d]
  19. def exponent_bias(d:DType) -> int: return {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 14}[d]
  20. def exponent_mask(d:DType) -> int: return {dtypes.float64: 0x7FF, dtypes.float32: 0xFF, dtypes.float16: 0x1F}[d]
  21. def float_to_bits(d:UOp) -> UOp:
  22. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  23. cast_to = {dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[d.dtype]
  24. return d.bitcast(cast_to)
  25. def bits_to_float(d:UOp, float_dtype:DType) -> UOp:
  26. assert d.dtype in [dtypes.uint64, dtypes.uint32, dtypes.uint16]
  27. cast_to = {dtypes.uint64: dtypes.float64, dtypes.uint32: dtypes.float32, dtypes.uint16: float_dtype}[d.dtype]
  28. return d.bitcast(cast_to)
  29. # **** utils ****
  30. def shr(x:UOp, y:int) -> UOp: return x // (2**y)
  31. def shl(x:UOp, y:int) -> UOp: return x * (2**y)
  32. def rintk(d:UOp) -> UOp:
  33. """ceiling(d:float) -> int"""
  34. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  35. return_t = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
  36. return (d + d.lt(0.0).where(d.const(-0.5), d.const(0.5))).cast(return_t)
  37. def pow2if(q:UOp, float_dtype:DType):
  38. """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
  39. assert q.dtype in (dtypes.int64, dtypes.int32, dtypes.int16, dtypes.uint32)
  40. final_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype, dtypes.uint32: dtypes.float32}[q.dtype]
  41. return shl((q + (exponent_bias(final_dtype)+1)), significand_bits(final_dtype)).bitcast(final_dtype)
  42. def ilogb2k(d:UOp) -> UOp:
  43. """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
  44. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  45. dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
  46. # -1 <= ilog2bk(d) <= 128
  47. # ((float_to_bits(d) >> significand_bits(dtype)) & exponent_mask(dtype)) - exponent_bias(dtype)
  48. return (shr(dint, significand_bits(d.dtype)) & exponent_mask(d.dtype)) - (exponent_bias(d.dtype)+1)
  49. def ldexp3k(d:UOp, e:UOp) -> UOp:
  50. """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
  51. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  52. dtype = d.dtype
  53. cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
  54. e = e.cast(cast_map[d.dtype])
  55. m1 = d.bitcast(cast_map[d.dtype])
  56. m2 = shl(e, significand_bits(d.dtype))
  57. return (m1 + m2).bitcast(d.dtype).cast(dtype)
  58. def ldexp2k(d:UOp, e:UOp) -> UOp:
  59. """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
  60. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
  61. return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
  62. def frexp(v:UOp) -> Tuple[UOp, UOp]:
  63. """frexp(v) -> (mantissa, exponent)"""
  64. assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  65. # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
  66. m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
  67. m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3C00}[v.dtype]
  68. bias = {dtypes.float64: 1022, dtypes.float32: 126, dtypes.float16: 15}[v.dtype]
  69. bits = float_to_bits(v)
  70. exponent = shr(bits, significand_bits(v.dtype)) & exponent_mask(v.dtype)
  71. exponent_zero = exponent.ne(0.0)
  72. result_f = bits_to_float((bits & m1) | m2, v.dtype)
  73. value = exponent_zero.where(result_f, v)
  74. exp = exponent + (-bias)
  75. exp = exponent_zero.where(exp, exp.const(0))
  76. if v.dtype == dtypes.float16: exp = exp.bitcast(dtypes.int16)
  77. return value, exp
  78. def mla(x:UOp, y:UOp, z:UOp) -> UOp: return x * y + z
  79. def polyN(u:UOp, s:UOp, coeffs:List[float]) -> UOp: return functools.reduce(lambda u,c: mla(u, s, u.const(c)), coeffs, u)
  80. # *** reduction algorithms for sine ***
  81. def payne_hanek_reduction(d:UOp) -> Tuple[UOp, UOp]:
  82. """
  83. Performs Payne-Hanek Reduction: computes the remainder of `d` modulo pi/2 for the values `d` where
  84. 39800.0 <= d <= +Inf
  85. Returns a tuple of `(r, q)`:
  86. - `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
  87. ensuring that `r` is in the range of [0, pi/2).
  88. - `q`[int32] is an integer taking values 0,1,2 or 3, corresponding to the quadrant of the original angle `d`.
  89. """
  90. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  91. two_over_pi_f = [0x00000000,0x28be60db,0x9391054a,0x7f09d5f4,0x7d4d3770,0x36d8a566,0x4f10e410]
  92. input_dtype: DType = d.dtype
  93. dtype_via = dtypes.float32 if d.dtype == dtypes.float16 else d.dtype
  94. acc_dtype = dtypes.uint64
  95. f, e = frexp(d)
  96. ia = (f.cast(dtype_via) * 4.294967296e9).cast(dtypes.uint64)
  97. i = shr(e.cast(dtypes.uint64), 5)
  98. e = (e.cast(dtypes.uint64) & 31).cast(dtypes.uint32)
  99. offset = -e + 32
  100. def _eq(arr:UOp, eq_to:int) -> UOp: return arr.ne(eq_to)
  101. def _take(an:UOp, offset:int, count:int=0) -> UOp:
  102. """an = two_over_pi_f[i+offset]"""
  103. if count+offset <= len(two_over_pi_f[0:-2]):
  104. an = _eq(i, count).where(_take(an, offset, count=count+1), an.const(two_over_pi_f[count+offset]))
  105. return an
  106. def _exact_pow2if(x): return pow2if(x, input_dtype).cast(acc_dtype)
  107. def _shl_lazy(x, y): return (x.cast(acc_dtype) * _exact_pow2if(y)).cast(dtypes.uint32)
  108. def _shr_lazy(x, y): return (x.cast(acc_dtype) // _exact_pow2if(y)).cast(dtypes.uint32)
  109. # a_n = (two_over_pi_f[Int(i) + n] << e) | (two_over_pi_f[Int(i) + n+1] >> (nbits - e))
  110. a1 = _take(i.const(0).cast(dtypes.uint32), 0)
  111. a2 = _take(i.const(0).cast(dtypes.uint32), 1)
  112. a3 = _take(i.const(0).cast(dtypes.uint32), 2)
  113. a4 = _take(i.const(0).cast(dtypes.uint32), 3)
  114. # Note: e >= 1 for all numbers d >= 1.0. assume e != 0
  115. hi = _shl_lazy(a1, e) | _shr_lazy(a2, offset)
  116. mi = _shl_lazy(a2, e) | _shr_lazy(a3, offset)
  117. lo = _shl_lazy(a3, e) | _shr_lazy(a4, offset)
  118. def _hp_mul(x:UOp, y:UOp) -> UOp: return x.cast(dtypes.uint64) * y.cast(dtypes.uint64)
  119. p = _hp_mul(ia, lo)
  120. p = _hp_mul(ia, mi) + shr(p, 32)
  121. p = shl(_hp_mul(ia, hi), 32) + p
  122. q = shr(p, 62).cast(dtypes.int32)
  123. p = p & 0x3fffffffffffffff
  124. r = (p.cast(dtype_via) * (3.4061215800865545e-19)).cast(input_dtype)
  125. # if fraction >= 0.5, r -= pi/2, q += 1
  126. return f.lt(0.5).where(r, r + r.const(-math.pi / 2)), f.lt(0.5).where(q, q + 1)
  127. def cody_waite_reduction(d:UOp) -> Tuple[UOp, UOp]:
  128. """
  129. Performs Cody-Waite Reduction: computes the reminder of `d` modulo pi/2 for the values `d` where
  130. 0 <= abs(d) <= 39800.0
  131. Returns a tuple of `(r, q)`, where the output format is the same as that of `payne_hanek_reduction`.
  132. """
  133. m_1_pi = 0.318309886183790671537767526745028724
  134. qdh = (d * (m_1_pi / 16777216)).cast(dtypes.int64).cast(d.dtype) * 16777216.0
  135. def _quadrant(x:UOp) -> UOp:
  136. if x.dtype == dtypes.float64: return rintk(mla(d, d.const(m_1_pi), -qdh)).cast(x.dtype)
  137. return rintk(x * m_1_pi).cast(x.dtype)
  138. def _reduce_d(x:UOp, q:UOp):
  139. if x.dtype == dtypes.float64:
  140. d = mla(qdh, x.const(-3.1415926218032836914), x)
  141. d = mla(q, x.const(-3.1415926218032836914), d)
  142. d = mla(qdh, x.const(-3.1786509424591713469e-08), d)
  143. d = mla(q, x.const(-3.1786509424591713469e-08), d)
  144. d = mla(qdh, x.const(-1.2246467864107188502e-16), d)
  145. d = mla(q, x.const(-1.2246467864107188502e-16), d)
  146. d = mla(qdh + q, x.const(-1.2736634327021899816e-24), d)
  147. elif x.dtype == dtypes.float16:
  148. # [FIXME] when reducing `d`, FP16 needs FP32 precision to achieve 1.0 ULP precision.
  149. d = _reduce_d(x.cast(dtypes.float32), q.cast(dtypes.float32)).cast(dtypes.float16)
  150. else:
  151. d = mla(q, x.const(-3.1414794921875), x)
  152. d = mla(q, x.const(-0.00011315941810607910156), d)
  153. d = mla(q, x.const(-1.9841872589410058936e-09), d)
  154. d = mla(q, x.const(-1.2154201256553420762e-10), d)
  155. return d
  156. return _reduce_d(d, (q := _quadrant(d))), q.cast(dtypes.int32)
  157. # *** approximate sine on small angle. ***
  158. def trig_poly(d:UOp, coeff32, coeff64):
  159. u = None
  160. s = d * d
  161. if d.dtype == dtypes.float64:
  162. s2 = s * s
  163. s4 = s2 * s2
  164. def __poly4(x:UOp, x2:UOp, c3, c2, c1, c0) -> UOp: return mla(x2, mla(x, x.const(c3), x.const(c2)), mla(x, x.const(c1), x.const(c0)))
  165. def __poly8(x, x2, x4, c7, c6, c5, c4, c3, c2, c1, c0) -> UOp: return mla(x4, __poly4(x, x2, c7, c6, c5, c4), __poly4(x, x2, c3, c2, c1, c0))
  166. u = __poly8(s, s2, s4, *coeff64[:-1])
  167. u = mla(u, s, d.const(coeff64[-1]))
  168. else:
  169. u = polyN(s.const(coeff32[0]), s, coeff32[1:])
  170. return mla(s, u * d, d)
  171. # approximate sine on [-pi/2, pi/2]
  172. def sin_poly(d:UOp) -> UOp: return trig_poly(d, [2.6083159809786593541503e-06, -0.0001981069071916863322258, 0.00833307858556509017944336, -0.166666597127914428710938], [-7.97255955009037868891952e-18, 2.81009972710863200091251e-15, -7.64712219118158833288484e-13, 1.60590430605664501629054e-10, -2.50521083763502045810755e-08, 2.75573192239198747630416e-06, -0.000198412698412696162806809, 0.00833333333333332974823815, -0.166666666666666657414808]) # noqa: E501
  173. def sin_poly_small(d:UOp, q:UOp) -> UOp:
  174. def _ifand(n:int): return (q & n).ne(0)
  175. r = sin_poly(d)
  176. return r * _ifand(1).where(r.const(-1), r.const(1))
  177. def sin_poly_large(d:UOp, q:UOp) -> UOp:
  178. def _ifand(n:int): return (q & n).ne(0)
  179. d = d + _ifand(1).where(d.const(math.pi / 2), d.const(0))
  180. r = sin_poly(d)
  181. return r * _ifand(2).where(r.const(-1), r.const(1))
  182. # *** toplevel functions for xsin/xlog2/xexp2 ***
  183. def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
  184. """
  185. Implements a 1.0 ULP approximation for UnaryOps.SIN.
  186. - fast=True assumes x <= switch_over.
  187. - switch_over is the threshold for switching to payne_hanek_reduction.
  188. """
  189. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  190. reduction_algo = cody_waite_reduction if fast else payne_hanek_reduction
  191. # mask +-inf/nan as zero
  192. x = _lazy_map_numbers(d, d.const(0.0), d.const(0.0), d.const(0.0), d)
  193. # x_sign = sign(x)
  194. x_sign = x.ne(0).where(x.lt(0).where(x.const(-1), x.const(1)), x.const(0))
  195. x_abs = x * x_sign
  196. r, q = reduction_algo(x_abs)
  197. if fast: result = sin_poly_small(r, q)
  198. else:
  199. # Payne Hanek Reduction assumes abs(x) >= pi/4, so for smaller values, use cody_waite_reduction.
  200. switch_over_map = x_abs.lt(switch_over)
  201. r_fast, q_fast = cody_waite_reduction(x_abs)
  202. r = switch_over_map.where(r_fast, r)
  203. q = switch_over_map.where(q_fast, q)
  204. result = switch_over_map.where(sin_poly_small(r, q), sin_poly_large(r, q))
  205. result = result * x_sign # adjusts the sign for abs(x).
  206. # sin(Inf) = NaN, sin(-Inf) = NaN, sin(NaN) = NaN
  207. return _lazy_map_numbers(d, d.const(math.nan), d.const(math.nan), d.const(math.nan), result)
  208. def xexp2(x:UOp) -> UOp:
  209. """
  210. Implements a 1.0 ULP approximation for UnaryOps.EXP2
  211. - Paper: https://arxiv.org/pdf/2001.09258
  212. """
  213. assert x.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  214. fp64_p = x.dtype == dtypes.float64
  215. # mask +=inf/nan as zero.
  216. d = _lazy_map_numbers(x, x.const(0.0), x.const(0.0), x.const(0.0), x)
  217. q = rintk(d)
  218. # s = d - round(d)
  219. s = d - q.cast(d.dtype)
  220. # a polynomial approximation with 13 non-zero terms in the range of [−(log 2)/2,(log 2)/2].
  221. if fp64_p:
  222. u = polyN(s.const(0.4434359082926529454e-9), s, [0.7073164598085707425e-8, 0.1017819260921760451e-6, 0.1321543872511327615e-5, 0.1525273353517584730e-4, 0.1540353045101147808e-3, 0.1333355814670499073e-2, 0.9618129107597600536e-2, 0.5550410866482046596e-1, 0.2402265069591012214e+0, 0.6931471805599452862e+0, 0.1000000000000000000e+1]) # noqa: E501
  223. else:
  224. u = polyN(s.const(0.1535920892e-3), s, [0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 0.1000000000e+1])
  225. u = ldexp2k(u, q) # u*2^q
  226. upper = {dtypes.float64: 1024, dtypes.float32: 128, dtypes.float16: 23.0}[d.dtype]
  227. lower = {dtypes.float64: -2000, dtypes.float32: -150, dtypes.float16: -22}[d.dtype]
  228. # Replace x >= upper with +inf
  229. u = d.ne(upper).where(u, d.const(math.inf))
  230. u = d.lt(upper).where(u, d.const(math.inf))
  231. # Replace x <= lower with zero.
  232. u = d.lt(lower).where(d.const(0.0), u)
  233. # x=NaN never satisfies x < Inf. (for fastmode)
  234. u = d.lt(math.inf).where(u, u.const(math.nan))
  235. # exp2(Inf) = Inf, exp2(-Inf) = 0, exp2(NaN) = NaN
  236. return _lazy_map_numbers(x, x.const(math.inf), x.const(0.0), x.const(math.nan), u)
  237. def xlog2(d:UOp) -> UOp:
  238. """
  239. Implements a 1.0 ULP approximation for UnaryOps.LOG2
  240. Paper: https://arxiv.org/pdf/2001.09258
  241. """
  242. assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
  243. fp64_p = d.dtype == dtypes.float64
  244. FLT_MIN = d.const(1e-6 if d.dtype == dtypes.float16 else 1e-4)
  245. d_orig = d
  246. denormal_map = d.lt(FLT_MIN)
  247. for _ in range(8): d = denormal_map.where(d * (2 ** 8), d)
  248. e = ilogb2k(d * (1.0 / 0.75)).cast(d.dtype)
  249. m = ldexp3k(d, -e)
  250. e = denormal_map.where(e + (-64), e)
  251. if fp64_p:
  252. x = (m - 1.0) * (m + 1.0).recip()
  253. x2 = x * x
  254. t = polyN(x.const(0.2211941750456081490e+0), x2, [0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0, 0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449]) # noqa: E501
  255. s_hi, s_lo = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(t.const(2.885390081777926774), t.const(0), x, x.const(0)))
  256. r = mla(t, x * x2, s_hi + s_lo)
  257. else:
  258. xx, xy = dfdiv2_f2_f2_f2(*dfadd2_f2_f2_f2(m.const(-1), m.const(0), m, m.const(0)), *dfadd2_f2_f2_f2(m.const(1), m.const(0), m, m.const(0)))
  259. x2 = xx * xx
  260. t = polyN(d.const(0.4374550283e+0), x2, [0.5764790177e+0, 0.9618012905120])
  261. sx, sy = dfadd2_f2_f2_f2(e, e.const(0), *dfmul2_f2_f2_f2(xx, xy, xx.const(2.8853900432586669922), xy.const(3.2734474483568488616e-08)))
  262. sx, sy = dfadd2_f2_f2_f2(sx, sy, x2.const(0), (x2 * xx) * t)
  263. r = sx + sy
  264. # log2(Inf) = Inf
  265. r = d_orig.ne(math.inf).where(r, r.const(math.inf))
  266. # log2(x=-0.01) = NaN. where x < 0
  267. r = d_orig.lt(-0.0).where(r.const(math.nan), r)
  268. # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
  269. # log2_zero = the value of unmasked xlog2(0.0).
  270. log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79, None: -math.inf}[d.dtype]
  271. r = r.ne(log2_zero).where(r, r.const(-math.inf))
  272. # log(NaN) = NaN, using for all real number x, either of x < Inf, x == Inf becomes True.
  273. r = d_orig.lt(math.inf).where(r, d_orig.ne(math.inf).where(d.const(math.nan), d))
  274. # log(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
  275. return d_orig.recip().ne(-math.inf).where(r, r.const(-math.inf))