# test_uops.py (18 KB)
  1. from typing import Optional, Tuple, Any, List
  2. import unittest, math
  3. import numpy as np
  4. from tinygrad.tensor import Tensor, _to_np_dtype
  5. from tinygrad.helpers import CI, DEBUG, getenv
  6. from tinygrad.dtype import dtypes, DType, PtrDType
  7. from tinygrad.device import Buffer, Device
  8. from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, exec_alu
  9. from tinygrad.renderer import Program
  10. from tinygrad.engine.schedule import create_schedule
  11. from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
  12. from tinygrad.codegen.uops import UOps, UOp
  13. from tinygrad.codegen.uopgraph import UOpGraph
  14. from test.helpers import is_dtype_supported
  15. def _uops_to_prg(uops_list, print_uops=False):
  16. uops = UOpGraph(uops_list)
  17. src = Device[Device.DEFAULT].renderer.render("test", uops)
  18. if print_uops: uops.print()
  19. has_local = Device[Device.DEFAULT].renderer.has_local
  20. return CompiledRunner(Program("test", src, Device.DEFAULT, [1,1,1] if has_local else None, [1,1,1] if has_local else None, uops=uops))
  21. def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
  22. uops.append(UOp(uop, dtype, tuple(src), arg))
  23. return uops[-1]
  24. def _test_single_value(vals, op, dts):
  25. uops = []
  26. output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
  27. buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
  28. buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, False)) for i,dtype in enumerate(dts)]
  29. loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
  30. alu = uop(uops, UOps.ALU, output_dtype, loads, op)
  31. out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
  32. buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
  33. buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
  34. prg = _uops_to_prg([out])
  35. prg.exec([buf]+buf2)
  36. ret = np.empty(1, _to_np_dtype(output_dtype))
  37. buf.copyout(ret.data)
  38. return ret[0]
  39. def _test_single_value_const(vals, op, dts):
  40. uops = []
  41. output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
  42. buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
  43. loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
  44. alu = uop(uops, UOps.ALU, output_dtype, loads, op)
  45. out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
  46. buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
  47. prg = _uops_to_prg([out])
  48. prg.exec([buf])
  49. ret = np.empty(1, _to_np_dtype(output_dtype))
  50. buf.copyout(ret.data)
  51. return ret[0]
  52. def _test_uops_result(output_dtype, uops, res):
  53. # uops = []
  54. buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
  55. # res = output_fn(uops)
  56. out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
  57. buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
  58. prg = _uops_to_prg([out], print_uops=True)
  59. prg.exec([buf])
  60. ret = np.empty(1, _to_np_dtype(output_dtype))
  61. buf.copyout(ret.data)
  62. return ret[0]
  63. class TestUOps(unittest.TestCase):
  64. def _equal(self, v1, v2):
  65. assert isinstance(v2, (float, int, bool))
  66. if isinstance(v2, float):
  67. np.testing.assert_allclose(v1, v2, rtol=2e-7)
  68. else:
  69. np.testing.assert_equal(v1, v2)
  70. def _test_uop_fxn(self, op, fxn, dts=(dtypes.float32, )):
  71. for f in [_test_single_value, _test_single_value_const]:
  72. for a in [-2.0, 0.0, 1.0]:
  73. a = dtypes.as_const(a, dts[0])
  74. self._equal(f([a], op, dts), fxn(a))
  75. def _test_bop_fxn(self, op, fxn, dts=(dtypes.float32, )*2, no_b_zero=False, no_b_neg=False):
  76. for f in [_test_single_value, _test_single_value_const]:
  77. for a in [-2.0, 0.0, 1.0]:
  78. for b in [-3.0, 1.0] + ([] if no_b_zero else [0.0]):
  79. a = dtypes.as_const(a, dts[0])
  80. b = dtypes.as_const(abs(b) if no_b_neg else b, dts[1])
  81. self._equal(f([a,b], op, dts), fxn(a,b))
  82. def _test_top_fxn(self, op, fxn, dts=(dtypes.float32, )*3):
  83. for f in [_test_single_value, _test_single_value_const]:
  84. for a in [-2.0, 0, 1]:
  85. for b in [-3.0, 3.0]:
  86. for c in [-4.0, 4.0]:
  87. a = dtypes.as_const(a, dts[0])
  88. b = dtypes.as_const(b, dts[1])
  89. c = dtypes.as_const(c, dts[2])
  90. self._equal(f([a,b,c], op, dts), fxn(a,b,c))
  91. class TestFloatUOps(TestUOps):
  92. def test_neg(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a)
  93. def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
  94. def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
  95. def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
  96. def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf'))
  97. def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
  98. def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
  99. def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
  100. def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
  101. def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
  102. # MOD isn't tested on floats
  103. def test_where(self):
  104. self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
  105. @unittest.skipUnless(getenv("PYTHON"), "only python supports MULACC")
  106. def test_mulacc(self):
  107. self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: a*b+c, (dtypes.float, dtypes.float, dtypes.float))
  108. class TestNonFloatUOps(TestUOps):
  109. def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (dtypes.int32, ))
  110. def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
  111. def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
  112. @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
  113. def test_shr_int32(self): self._test_bop_fxn(BinaryOps.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
  114. @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts")
  115. def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)<<int(b), (dtypes.int32, dtypes.int32), no_b_neg=True)
  116. def test_div_int32(self):
  117. self._test_bop_fxn(BinaryOps.IDIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
  118. def test_and_int32(self): self._test_bop_fxn(BinaryOps.AND, lambda a,b: int(a)&int(b), (dtypes.int32, dtypes.int32))
  119. def test_or_int32(self): self._test_bop_fxn(BinaryOps.OR, lambda a,b: int(a)|int(b), (dtypes.int32, dtypes.int32))
  120. def test_mod_int32(self):
  121. self._test_bop_fxn(BinaryOps.MOD,
  122. lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
  123. def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (dtypes.int32, dtypes.int32))
  124. @unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
  125. def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
  126. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
  127. def test_where_float16(self):
  128. self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
  129. class TestBoolUOps(TestUOps):
  130. def _test_uop_bool_fxn(self, op, fxn):
  131. for f in [_test_single_value, _test_single_value_const]:
  132. for a in [False, True]:
  133. self._equal(f([a], op, (dtypes.bool, )*1), fxn(a))
  134. def _test_bop_bool_fxn(self, op, fxn):
  135. for f in [_test_single_value, _test_single_value_const]:
  136. for a in [False, True]:
  137. for b in [False, True]:
  138. self._equal(f([a,b], op, (dtypes.bool, )*2), fxn(a,b))
  139. def _test_top_bool_fxn(self, op, fxn):
  140. for f in [_test_single_value, _test_single_value_const]:
  141. for a in [False, True]:
  142. for b in [False, True]:
  143. for c in [False, True]:
  144. self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))
  145. def test_not_bool(self): self._test_uop_bool_fxn(UnaryOps.NEG, lambda a: not a)
  146. def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
  147. def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
  148. def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
  149. def test_and_bool(self): self._test_bop_bool_fxn(BinaryOps.AND, lambda a,b: a & b)
  150. def test_or_bool(self): self._test_bop_bool_fxn(BinaryOps.OR, lambda a,b: a | b)
  151. def test_cmpne_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPNE, lambda a,b: a != b)
  152. def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
  153. def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
  154. class TestExecALU(TestUOps):
  155. def test_sqrt(self):
  156. self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.float, (0.0,)), 0.0)
  157. def test_div(self):
  158. self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (8, 2)), 4)
  159. self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, 3)), 2)
  160. self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (7, -3)), -2)
  161. self.assertEqual(exec_alu(BinaryOps.IDIV, dtypes.int8, (-50, 6)), -8)
  162. np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (3.0,)))), 2+(1.0/3.0))
  163. np.testing.assert_allclose(exec_alu(BinaryOps.MUL, dtypes.float32, (7.0, exec_alu(UnaryOps.RECIP, dtypes.float32, (-3.0,)))), -2-(1.0/3.0))
  164. def test_recip(self):
  165. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (8,)), 1/8)
  166. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (7,)), 1/7)
  167. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-3,)), 1/-3)
  168. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (-50,)), 1/-50)
  169. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((32+521+3),)), 1/(32+521+3))
  170. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, ((34**2),)), 1/(34**2))
  171. np.testing.assert_allclose(exec_alu(UnaryOps.RECIP, dtypes.float32, (10,)), 1/10)
  172. def test_bool_neg(self):
  173. self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (False,)), True)
  174. self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (True,)), False)
  175. def test_bool_cmplt(self):
  176. self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False)
  177. self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True)
  178. self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
  179. self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
  180. def test_bool_where(self):
  181. self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
  182. self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
  183. np.testing.assert_allclose(exec_alu(TernaryOps.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
  184. def test_overflow(self):
  185. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
  186. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
  187. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1)), 255)
  188. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (0, -1000)), 24)
  189. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
  190. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
  191. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-100, -100)), 56)
  192. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-1000, -0)), 24)
  193. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-130, -0)), 126)
  194. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1, 1)), 2)
  195. self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
  196. class TestConstantFolding(unittest.TestCase):
  197. def test_cast_const(self):
  198. t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
  199. si = create_schedule([t.lazydata])
  200. assert len(si) == 0
  201. def test_bitcast_const(self):
  202. t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
  203. si = create_schedule([t.lazydata])
  204. assert len(si) == 1
  205. ji = lower_schedule_item(si[-1])
  206. assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
  207. class TestGatedStoreRewrite(unittest.TestCase):
  208. @unittest.skip("not yet implemented")
  209. def test_wrap_store_parents(self):
  210. # wraps all store parents in the valid branch
  211. gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
  212. gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
  213. idx = gidx0 * UOp.const(dtypes.int, 2)
  214. value = UOp(UOps.CONST, dtypes.float, (), 42.0)
  215. gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
  216. uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))])
  217. if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
  218. if_uop = next(u for u in uops if u.op is UOps.IF)
  219. endif = next(u for u in uops if u.op is UOps.ENDIF)
  220. assert endif.src[0] is if_uop
  221. nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
  222. assert nested_uops == (gmem, gidx0, idx, value)
  223. @unittest.skip("not yet implemented")
  224. def test_wrap_some_parents(self):
  225. # some parents are used outside the branch
  226. gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
  227. gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
  228. gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
  229. idx = gidx0 * UOp.const(dtypes.int, 2)
  230. value0 = UOp(UOps.CONST, dtypes.float, (), 42.0)
  231. value1 = UOp(UOps.CONST, dtypes.float, (), 43.0)
  232. gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
  233. outs = [UOp(UOps.STORE, None, (gmem0, idx, value0, gate))]
  234. outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1)))
  235. uops = UOpGraph(outs)
  236. if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
  237. if_uop = next(u for u in uops if u.op is UOps.IF)
  238. endif = next(u for u in uops if u.op is UOps.ENDIF)
  239. assert endif.src[0] is if_uop
  240. nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
  241. assert nested_uops == (gmem0, value0)
  242. class TestLocalAccess(unittest.TestCase):
  243. # NOTE: this is failing on METAL CI, no idea why. Works locally.
  244. @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "failing only in CI")
  245. @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
  246. def test_local_basic(self):
  247. uops = []
  248. smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
  249. st = uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
  250. barr = uop(uops, UOps.BARRIER, None, (st,))
  251. sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
  252. self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
  253. @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
  254. def test_local_indirect(self):
  255. uops = []
  256. smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
  257. st1 = uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
  258. st2 = uop(uops, UOps.STORE, None, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
  259. barr = uop(uops, UOps.BARRIER, None, (st1,st2))
  260. ofs = uop(uops, UOps.LOAD, dtypes.int32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), barr))
  261. sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
  262. self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
  263. @unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
  264. class TestAssembly(unittest.TestCase):
  265. def test_bitshift_left(self):
  266. g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
  267. c1 = UOp(UOps.CONST, dtypes.int, (), 2)
  268. c2 = UOp(UOps.CONST, dtypes.int, (), 3)
  269. l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
  270. a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
  271. a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
  272. uops = UOpGraph([a1,a2])
  273. Device[Device.DEFAULT].renderer.render("test", uops)
  274. self.assertEqual(uops.uops[-1].arg, BinaryOps.SHL)
  275. self.assertEqual(uops.uops[-2].arg, BinaryOps.MUL)
  276. def test_bitshift_right(self):
  277. g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
  278. c1 = UOp(UOps.CONST, dtypes.int, (), 2)
  279. c2 = UOp(UOps.CONST, dtypes.int, (), 3)
  280. l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
  281. a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
  282. a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
  283. uops = UOpGraph([a1,a2])
  284. Device[Device.DEFAULT].renderer.render("test", uops)
  285. self.assertEqual(uops.uops[-1].arg, BinaryOps.SHR)
  286. self.assertEqual(uops.uops[-2].arg, BinaryOps.IDIV)
  287. class TestUOpCompare(unittest.TestCase):
  288. def test_alu_same_src_different_arg(self):
  289. a = UOp(UOps.CONST, dtypes.float, (), 2.0)
  290. b = UOp(UOps.CONST, dtypes.float, (), 3.0)
  291. add = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)
  292. mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
  293. assert (add < mul) or (mul < add), "add and mul with same src should have an order"
  294. if __name__ == '__main__':
  295. unittest.main(verbosity=2)