# test_linearizer.py

from typing import List, Tuple, Dict, Union
import numpy as np
import unittest
from dataclasses import replace
from test.external.fuzz_linearizer import compare_linearizer

from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.device import Device, Buffer
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, MetaOps, TernaryOps, ReduceOps, UnaryOps
from tinygrad.renderer import TensorCore
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import DEBUG, prod, Context, getenv, CI, flatten, dedup
from tinygrad.dtype import DType, dtypes
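# helper_realized_ast: schedule r, run every kernel except the last one so the final kernel's
# inputs are realized, then return that kernel's AST plus its buffers ([*outputs, *inputs])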
def helper_realized_ast(r:Union[Tensor, List[Tensor]]):
  if isinstance(r, Tensor): r = [r]
  s = create_schedule([x.lazydata for x in r])
  run_schedule(s[:-1]) # run all kernels except the last one
  # now all input LazyBuffers in s[-1] should be realized
  # allocate output buffers
  output_buffers = [Buffer((out).device, out.size, out.dtype).allocate() for out in s[-1].outputs]
  return s[-1].ast, output_buffers+list(s[-1].inputs)
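# helper_tc_allclose: build an (m,k) @ (k,n) matmul in the given dtypes, force the tensor core
# opt onto the kernel, then check both that WMMA uops are emitted and that the result matches
# numpy within dtype-dependent tolerances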
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0):
  a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
  np_a, np_b = a.numpy(), b.numpy()
  r = a.matmul(b, acc_dtype=dtype_out)
  sched = create_schedule([r.lazydata])
  realized_ast = sched[-1].ast
  run_schedule(sched)
  out = r.numpy()
  k = Kernel(realized_ast)
  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
  k.linearize()
  assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
  assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
  np_c = np_a @ np_b
  if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
  elif dtype_in == dtypes.bfloat16: tc_atol, tc_rtol = 1e-2, 3e-3
  else: tc_atol, tc_rtol = 5e-3, 1e-4
  np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
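# helper_tc_ensure_uops_and_opts_count: same matmul setup as above, but only counts WMMA uops and
# applied TC opts to assert whether the tensor core path was (or, with ensure_triggered=False, was not) used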
def helper_tc_ensure_uops_and_opts_count(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
  a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
  r = a.matmul(b, acc_dtype=dtype_out)
  sched = create_schedule([r.lazydata])
  realized_ast = sched[-1].ast
  k = Kernel(realized_ast)
  k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
  k.linearize()
  wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
  tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
  if ensure_triggered:
    assert wmmas > 0, "tensor core not triggered"
    assert tcs == 1, "tensor core opt not included"
  else:
    assert wmmas == 0, "tensor core is incorrectly triggered"
    assert tcs == 0, "tensor core opt is incorrectly included"
class TestLinearizer(unittest.TestCase):
  def test_arg_dedup(self):
    a, b = Tensor.randn(4), Tensor.randn(4)
    np_a, np_b = a.numpy(), b.numpy()
    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
    lowered = list(lower_schedule(create_schedule([c.lazydata])))
    for ei in lowered: ei.run()
    rawbufs = lowered[-1].bufs
    assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
    np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
    np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
  def test_load_removed(self):
    a = Tensor.rand(1).realize()
    b = Tensor.rand(1).realize()
    ta = Tensor.where(Tensor(True), a, b).numpy()
    tb = Tensor.where(Tensor(False), a, b).numpy()
    np.testing.assert_equal(a.numpy(), ta)
    np.testing.assert_equal(b.numpy(), tb)
  def test_multioutput(self):
    dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
    a = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtype, st=st))
    b = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtype, st=st))
    out0 = LazyOp(BufferOps.STORE, (LazyOp(op=BinaryOps.ADD, src=(a,b)),), MemBuffer(idx=0, dtype=dtype, st=st))
    out1 = LazyOp(BufferOps.STORE, (LazyOp(op=BinaryOps.MUL, src=(a,b)),), MemBuffer(idx=1, dtype=dtype, st=st))
    a_t = Tensor.full(st.shape, 2).contiguous().realize()
    b_t = Tensor.full(st.shape, 3).contiguous().realize()
    lin = helper_linearizer_ast((out0, out1), [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]
    stores = [u for u in lin.uops if u.op is UOps.STORE]
    mutable_bufs = [u for u in lin.uops if u.op is UOps.DEFINE_GLOBAL and u.arg[-1]]
    assert len(mutable_bufs) == len(stores) == 2
    assert [u.arg[0] for u in mutable_bufs] == [0, 1]
  @unittest.skip("TODO: fix uops toposort")
  def test_sum_multireduce(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(32, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((1, 32)).expand((32, 32))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (1,))
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((32, 1))))
    squares = (second_x-first_reduce)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (0,))
    store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((1, 1))))
    wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1)
    helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
  @unittest.skip("TODO: fix uops toposort")
  def test_double_sum_multireduce(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(2, 32, 4, 16, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((2, 1, 32, 4, 1, 16)).expand((2, 32, 32, 4, 16, 16))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,5))
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((2, 32, 1, 4, 16, 1))))
    squares = (second_x-first_reduce)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,4))
    store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((2, 1, 1, 4, 1, 1))))
    wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((2,1,1,4,1,1))
    helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
  @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
  @unittest.skip("TODO: fix uops toposort")
  def test_var_multireduce(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
    # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1)))) # noqa: E501
    # store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
    # verify_lazyop(store)
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
    squares = (second_x-mean)*(second_x-mean)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1)))) # noqa: E501
    store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
    wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
    helper_linearizer_ast((store, ), [x], wanna_output=[wanna_output])
    # tinygrad ref
    y_tiny = x.var(axis=2, correction=0).reshape(3,27,1,1)
    np.testing.assert_allclose(y_tiny.numpy(), wanna_output, atol=1e-4, rtol=1e-4)
  # *** buildup to fused indexing
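  # an arange is expressed as a reduce: a masked CONST(1) view summed along the last axis, then minus one;
  # the tests below hand-construct that AST (see arange_input_st) so the expand/indexing can fuse into the reduce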
  @unittest.skipIf(CI, "very slow because of recomputing")
  def test_arange_expanded(self):
    # Tensor.arange(16384) expanded such that output shape is (4, 16384, 256, 1)
    # basically it's pushing the expand through this reduce:
    tiny = Tensor.arange(16384).reshape(16384, 1).expand(4, 16384, 256).reshape(4, 16384, 256, 1)
    real_arange = np.broadcast_to(np.arange(16384).reshape(16384, 1), (4, 16384, 256)).reshape(4, 16384, 256, 1)
    # NOTE: this is stupidly recomputing because it's not fused, but it proves a point.
    arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
      View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
    arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
    arange_axis = (3,)
    arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis)
    output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
    out = arange-LazyOp.const(1, dtypes.int, output_shape)
    store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(0, dtypes.int, st=ShapeTracker.from_shape(output_shape)))
    helper_linearizer_ast((store, ), [], wanna_output=[real_arange])
    with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
  @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
  def test_indexing_multireduce(self):
    arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
      View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
    # TODO: do this arange broadcast in the scheduler
    arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
    arange_axis = (3,)
    arange = LazyOp(ReduceOps.SUM, (LazyOp(BufferOps.CONST, (), ConstBuffer(1, dtypes.int, arange_input_st)), ), arange_axis)
    arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
    arange = arange-LazyOp.const(1, dtypes.int, arange_out_shape)
    # p2: the indexing
    dataset = Tensor.rand(16384, 256).realize()
    data1 = MemBuffer(1, dataset.dtype, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape))
    idxs = Tensor([0,3,5,6]).realize()
    data2 = MemBuffer(2, dtypes.int, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape))
    reduce_input = LazyOp(BufferOps.LOAD, (), data1)*LazyOp(UnaryOps.CAST, (arange.eq(LazyOp(BufferOps.LOAD, (), data2)),), dataset.dtype)
    out = LazyOp(ReduceOps.SUM, (reduce_input, ), (1,))
    output_shape = tuple(1 if i in out.arg else s for i,s in enumerate(arange_out_shape))
    store = LazyOp(BufferOps.STORE, (out, ), MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape(output_shape)))
    real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
    helper_linearizer_ast((store, ), [dataset, idxs], wanna_output=[real_index])
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_end_local(self):
    load = MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker.from_shape((32,)))
    store = MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker.from_shape((1,)))
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, arg=load),), arg=(0,)),), arg=store),
    load_t = Tensor.full(load.st.shape, 1).contiguous().realize()
    k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1]
    self.assertEqual(k.uops[-1].op, UOps.ENDIF)
    self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.op is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1]))
  def test_two_nested_range(self):
    a = Tensor.randn(2, ).realize()
    out = a.reshape(2, 1).expand(2, 3).sum()
    lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0]
    ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
    # RANGE -> LOAD -> RANGE -> PHI
    assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
  def test_three_nested_range(self):
    a = Tensor.randn(2, ).realize()
    out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
    lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0]
    ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
    # RANGE -> RANGE -> LOAD -> RANGE -> PHI
    # NOTE: nothing should toposort between the first two ranges
    assert ranges[0]+1 == ranges[1]
    assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]])
  def test_two_nested_range_alt_indexing(self):
    a = Tensor([2, 2]).realize()
    out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
    lin = helper_linearizer_opt(out, wanna_output=[24])[0]
    ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
    # RANGE -> ALU -> RANGE -> ALU + LOAD -> PHI
    assert any(x.op is UOps.ALU for x in lin.uops[ranges[0]:ranges[1]])
    assert not any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
    assert any(x.op in {UOps.ALU, UOps.LOAD} for x in lin.uops[ranges[1]:])
  def test_range_outer_op_before_phi(self):
    a = Tensor.randn(4, 1).realize()
    b = Tensor.randn(1, 1).realize()
    out = (a + b[0]).sum() + b[0]
    lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0]
    ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
    # LOAD -> RANGE -> LOAD -> PHI
    assert lin.uops[ranges[0]-2].op is UOps.LOAD
  # TODO: this test is brittle
  def test_range_outer_op_before_phi_nested_range(self):
    a = Tensor.randn(2, ).realize()
    b = Tensor.randn(1, 1).realize()
    out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
    lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0]
    ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
    if getenv("PTX"):
      # LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI
      assert lin.uops[ranges[0]-2].op is UOps.LOAD
      assert ranges[1] == ranges[0]+6
      assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
    # LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
    else:
      assert lin.uops[ranges[0]-2].op is UOps.LOAD
      assert ranges[1] == ranges[0]+3
      assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
  def test_range_outer_op_after_phi(self):
    a = Tensor.randn(4, 1).realize()
    out = a.sum() * a.sum()
    lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
    # RANGE -> LOAD -> PHI -> ALU
    end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
    assert lin.uops[end+1].op is UOps.ALU
  def test_range_outer_op_after_phi_nested_range(self):
    a = Tensor.randn(2, ).realize()
    out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum()
    lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
    # RANGE -> LOAD -> PHI -> ALU
    end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
    assert lin.uops[end+1].op is UOps.ALU
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skip("AST has implicit movement ops")
  def test_early_end_local(self):
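    # the one-line AST below computes std(x, axis=2) (sqrt of the mean of squared deviations),
    # matching the x.numpy().std(axis=2, ddof=0) reference at the end of this test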
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
    k = Kernel(ast)
    k.hand_coded_optimizations()
    k.linearize()
    self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF]))
    self.assertEqual(len(barriers:=[x for x in k.uops if x.op is UOps.BARRIER]), 3)
    self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].op, UOps.STORE)
    self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+1], barriers[1])
    self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].op, UOps.LOAD)
    self.assertLess(k.uops.uops.index(barriers[0]), k.uops.uops.index(ifs[0]))
    self.assertLess(k.uops.uops.index(ifs[0]), k.uops.uops.index(endifs[0]))
    self.assertLess(k.uops.uops.index(barriers[1]), k.uops.uops.index(ifs[1]))
    x = Tensor.randn(3,27,32).realize()
    helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_reduceops_order(self):
    # make sure the kernel puts reduceops in the order of their dependencies when they are passed to the Linearizer in arbitrary order
    load = MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker.from_shape((32,)))
    ast0 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load),), arg=(0,))
    ast1 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load), \
      LazyOp(op=UnaryOps.NEG, src=(ast0,), arg=None))),), arg=(0,))
    ast2 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(BinaryOps.ADD, src=(ast1, LazyOp(op=UnaryOps.NEG, \
      src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load),), arg=None))),), arg=(0,))
    ast3 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load), LazyOp(op=UnaryOps.NEG, src=(ast2,), arg=None))), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=load), LazyOp(op=UnaryOps.NEG, src=(ast0,), arg=None))),)),), arg=(0,)) # noqa E501
    for order in [(d, c, b, a) for d in range(4) for c in range(4) for b in range(4) for a in range(4) if len(set([a,b,c,d])) == 4]:
      asts = [
        LazyOp(op=BufferOps.STORE, src=(ast0,), arg=MemBuffer(idx=order.index(0), dtype=dtypes.float, st=ShapeTracker.from_shape((1,)))),
        LazyOp(op=BufferOps.STORE, src=(ast1,), arg=MemBuffer(idx=order.index(1), dtype=dtypes.float, st=ShapeTracker.from_shape((1,)))),
        LazyOp(op=BufferOps.STORE, src=(ast2,), arg=MemBuffer(idx=order.index(2), dtype=dtypes.float, st=ShapeTracker.from_shape((1,)))),
        LazyOp(op=BufferOps.STORE, src=(ast3,), arg=MemBuffer(idx=order.index(3), dtype=dtypes.float, st=ShapeTracker.from_shape((1,))))
      ]
      k = Kernel([asts[i] for i in order])
      def recursive_reduceops(x: LazyOp): return [c for v in x.src for c in recursive_reduceops(v)] + [v for v in list(x.src) if v.op in ReduceOps]
      for i,r in enumerate(k.reduceops): assert not any([r in recursive_reduceops(x) for x in k.reduceops[:i]]), "reduceops are out of order"
      x = Tensor.randn(32).realize()
      outs = [b:=(a:=x.numpy()).sum(), c:=(a - b).sum(), d:=(c - a).sum(), (a-d + a-b).sum()]
      helper_linearizer_ast(tuple(asts[i] for i in order), [x], wanna_output=[outs[i] for i in order])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skip("AST has implicit movement ops")
  def test_multireduce_store_locals(self):
    # ensure the result of the local reduceop is stored and loaded back into every thread for future use
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
    k = Kernel(ast)
    k.hand_coded_optimizations()
    k.linearize()
    local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL]
    self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.op is UOps.STORE and any([lb in u.src for lb in local_buf])]), 3, \
      f"should have generated 3 BufferOps.STORE to the local buf but got {len(real_local_stores)}")
    self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.op is UOps.LOAD and any([lb in u.src for lb in local_buf])]), 3, \
      f"should have generated 3 BufferOps.LOAD to the local buf but got {len(real_local_loads)}")
    self.assertEqual((real_local_stores[1].src[1].op, real_local_stores[1].src[1].arg), (UOps.CONST, 0))
    self.assertEqual((real_local_loads[1].src[1].op, real_local_loads[1].src[1].arg), (UOps.CONST, 0))
    x = Tensor.randn(3,27,32).realize()
    helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)])
  @unittest.skip("AST has implicit movement ops")
  def test_multireduce_upcasting(self):
    # when upcasting multiple reductions, ensure ast_parse will create multiple uops even when using the result of past reductions
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),), arg=None),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
    k = Kernel(ast)
    k.upcast()
    k.linearize()
    define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL]
    self.assertEqual(len([u for u in k.uops if u.op is UOps.LOAD and define_globals[1] in u.src]), 7)
    self.assertEqual(len([u for u in k.uops if u.op is UOps.ALU and u.arg is BinaryOps.ADD]), 25)
    opts = [[Opt(op=OptOps.UPCAST, axis=0, amt=2)], [Opt(op=OptOps.UPCAST, axis=0, amt=4)]]
    x = Tensor.randn(8,7).softmax().realize()
    helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
  @unittest.skip("TODO: fix uops toposort")
  def test_multireduce_unroll(self):
    # unrolled multireduceops hit an issue where a reduceop following another reduceop needs to bring the "unroll" back:
    # e.g. if you unroll into four values and those four values are summed, you then need four operations on that sum for the next reduceop
    Tensor.manual_seed(0)
    x = Tensor.randn(3, 27, 12, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 1, 12)).expand((3, 27, 12, 12))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/12, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 12, 1)))) # noqa: E501
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 12, 1))))
    squares = (second_x-mean)*(second_x-mean)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/12, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 1, 1)))) # noqa: E501
    store = LazyOp(BufferOps.STORE, (variance,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
    wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
    opts = [
      [Opt(op=OptOps.UNROLL, axis=0, amt=12)],
      [Opt(op=OptOps.UNROLL, axis=0, amt=6)],
      [Opt(op=OptOps.UNROLL, axis=0, amt=4)],
      [Opt(op=OptOps.UNROLL, axis=0, amt=3)],
      [Opt(op=OptOps.UNROLL, axis=0, amt=2)],
    ]
    helper_linearizer_ast((store,), [x], opts=opts, wanna_output=[wanna_output])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_multireduce_loop_scope(self):
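    # the one-line AST below computes sum((x - mean(x)) / std(x), axis=2); the numpy equivalent
    # is spelled out in the wanna_output expression at the end of this test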
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None))), LazyOp(op=UnaryOps.RECIP, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),)),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501
    k = Kernel(ast)
    k.hand_coded_optimizations()
    k.linearize()
    def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src])
    loop = None
    for u in k.uops:
      if u.op is UOps.RANGE: loop = u
      elif loop is None: continue
      elif u.op is UOps.ENDRANGE and loop in u.src: loop = None
      else: self.assertIn(loop, get_recursive_children(u), f"Any uop within a loop should depend on the loop: {u}")
    x = Tensor.randn(3, 27, 32).realize()
    helper_linearizer_ast(ast, [x], wanna_output= \
      [((x.numpy() - x.numpy().mean(axis=2, keepdims=True))/x.numpy().std(axis=2, keepdims=True, ddof=0)).sum(axis=2).reshape(-1)])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("TODO: fix uops toposort")
  def test_mean_std_multireduce(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
    squares = (second_x-mean)*(second_x-mean)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
    std = LazyOp(UnaryOps.SQRT, (variance,), None)
    store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
    wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
    helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("TODO: fix uops toposort")
  def test_mean_std_multireduce_mid_dim(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (2,))
    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 35)))) # noqa: E501
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35))))
    squares = (second_x-mean)*(second_x-mean)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (1,))
    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(0.04, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 1, 1, 35)))) # noqa: E501
    std = LazyOp(UnaryOps.SQRT, (variance,), None)
    store = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 1, 1, 35))))
    wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
    helper_linearizer_ast((store,), [x], wanna_output=[wanna_output])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.expectedFailure
  def test_mean_std_multireduce_multiout(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
    first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35))))
    first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
    mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 35, 1)))) # noqa: E501
    second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, x.lazydata.st.reshape((15, 25, 35, 1))))
    squares = (second_x-mean)*(second_x-mean)
    squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
    variance = squares_sum * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
    std = LazyOp(UnaryOps.SQRT, (variance,), None)
    third_reduce = LazyOp(ReduceOps.SUM, (second_x,), (2,))
    mean_out = third_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(1/35, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((15, 25, 1, 1)))) # noqa: E501
    store_mean = LazyOp(BufferOps.STORE, (mean_out,), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((15,25,1,1))))
    store_std = LazyOp(BufferOps.STORE, (std,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((15, 25, 1, 1))))
    wanna_output = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)]
    lins = helper_linearizer_ast((store_std,store_mean), [x], wanna_output=wanna_output)
    for k in lins:
      assert len([u for u in k.uops if u.op is UOps.DEFINE_ACC]) == 2, "got more than two accs (didn't reuse the mean reduce)"
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_softmax_multireduce(self):
    x = Tensor.rand(4, 32).realize()
    x_ast = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((4,32))))
    max_x = LazyOp(op=ReduceOps.MAX, src=(x_ast,), arg=(1,))
    centered_x = LazyOp(op=BinaryOps.ADD, src=(x_ast, LazyOp(op=UnaryOps.NEG, src=(max_x,), arg=None)))
    exp_x = LazyOp(op=UnaryOps.EXP2, src=(centered_x,))
    sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,))
    y = LazyOp(op=BinaryOps.MUL, src=(exp_x, LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,))))
    y_reduced = LazyOp(op=ReduceOps.SUM, src=(y,), arg=(1,))
    ast = LazyOp(op=BufferOps.STORE, src=(y_reduced,), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1))))
    expected = ((np_exp2:=np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)))/np_exp2.sum(axis=-1, keepdims=True)).sum(axis=-1)
    helper_linearizer_ast((ast,), [x], wanna_output=[expected])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_softmax_multireduce_multiout(self):
    x = Tensor.rand(4, 32).realize()
    x_ast = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((4,32))))
    max_x = LazyOp(op=ReduceOps.MAX, src=(x_ast,), arg=(1,))
    exp_x = LazyOp(op=UnaryOps.EXP2, src=(LazyOp(op=BinaryOps.ADD, src=(x_ast, LazyOp(op=UnaryOps.NEG, src=(max_x,), arg=None))),))
    sum_exp_x = LazyOp(op=ReduceOps.SUM, src=(exp_x,), arg=(1,))
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(exp_x, LazyOp(op=UnaryOps.RECIP, src=(sum_exp_x,)))),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1)))) # noqa: E501
    max_x_ast = LazyOp(op=BufferOps.STORE, src=(max_x,), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1))))
    sum_exp_x_ast = LazyOp(op=BufferOps.STORE, src=(sum_exp_x,), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((4,1))))
    expected = [
      ((np_exp2:=np.exp2(x.numpy()-(np_max_x:=x.numpy().max(axis=-1,keepdims=True))))/(sum_exp_x:=np_exp2.sum(axis=-1,keepdims=True))).sum(axis=-1,),
      np_max_x.reshape(-1), sum_exp_x.reshape(-1)
    ]
    helper_linearizer_ast((ast,max_x_ast,sum_exp_x_ast), [x], wanna_output=expected)
  def test_load_dedup(self):
    # for different leaves in the AST, the same loads may occur.
    a = Tensor.randn(4).realize()
    # these are of size 3 to avoid float4 coalesce
    r = a[:-1] + a[1:]
    k = Kernel(create_schedule([r.lazydata])[-1].ast)
    k.upcast()
    k.linearize()
    num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD])
    assert num_loads <= 4, "more load uops than needed"
    assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
  def test_load_cache_const_bufs(self):
    # make sure const buffers are differentiated from local and mem buffers
    ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)), dtypes.int
    VAL = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=2, dtype=DT, st=ST))
    # data1[0] + VAL
    a = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=DT, st=ST)), VAL))
    # (literal const 1) + VAL
    b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
    lin = Kernel(ast)
    lin.linearize()
    assert len(lin.uops.uops) <= 7, "too many uops"
    a_bufs = [u.op for u in lin.uops.uops[-1].src[2].src]
    assert a_bufs == [UOps.LOAD, UOps.CONST]
  def test_upcast_cse(self):
    # when upcasting, within a subtree, there may be common expressions.
    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = a.expand([2]) + b.expand([2])
    k = Kernel(create_schedule([r.lazydata])[-1].ast)
    k.upcast()
    k.linearize()
    num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
    assert num_ops <= 1, "more alu uops than needed"
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_reduce_upcast(self):
    x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
    r = Tensor.conv2d(x,w,padding=1).relu()
    k = Kernel(create_schedule([r.lazydata])[-1].ast)
    k.upcast()
    k.upcast()
    k.linearize()
    accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
    stores = [u for u in k.uops if u.op is UOps.STORE]
    assert len(accs) == 0 # it's removed now
    assert len(stores) == 1
    assert stores[0].src[-1].dtype == dtypes.float.vec(4)
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_upcast_with_locals(self):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    k = Kernel(create_schedule([r.lazydata])[-1].ast)
    k.hand_coded_optimizations()
    k.linearize()
    accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
    stores = [u for u in k.uops if u.op is UOps.STORE]
    # the first store is to lds and can be upcasted
    assert accs[0].dtype == stores[0].src[-1].dtype == dtypes.float.vec(4)
    assert stores[0].src[0].op is UOps.DEFINE_LOCAL
    # the second store is to gds with no upcasts
    assert accs[1].dtype == stores[1].src[2].dtype == dtypes.float
    assert stores[1].src[0].op is UOps.DEFINE_GLOBAL
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_upcast_multireduce_nested_local_upcast(self):
    x, y, z, w = [Tensor.rand((1,128) if i % 2 == 0 else (1,128,128)).realize() for i in range(4)]
    st0 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),))
    st1 = ShapeTracker(views=(View(shape=(1, 128, 128), strides=(0, 1, 128), offset=0, mask=None, contiguous=False),))
    ld0 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, st0))
    ld1 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.float, st1))
    ld2 = LazyOp(BufferOps.LOAD, (), MemBuffer(3, dtypes.float, st0))
    ld3 = LazyOp(BufferOps.LOAD, (), MemBuffer(4, dtypes.float, st1))
    r0 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (ld0, ld1)), ), (2,))
    r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (ld2, ld3)), ), (2,))
    out_st = ShapeTracker(views=(View(shape=(1, 128, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),))
    ast = (LazyOp(BufferOps.STORE, (LazyOp(BinaryOps.ADD, (r0, r1)), ), MemBuffer(0, dtypes.float, out_st)),)
    helper_linearizer_ast(ast, [x, y, z, w])
  def test_zero_fold(self):
    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
    r = Tensor.stack(a, b)
    k = Kernel(create_schedule([r.lazydata])[-1].ast)
    k.upcast()
    k.linearize()
    num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
    assert num_ops == 0, "more alu uops than needed"
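  # reductions over narrow dtypes should accumulate in a wider dtype:
  # bool/int16 sums accumulate in int, float16/bfloat16 sums accumulate in float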
  def test_sum_acc_dtype(self):
    for tensor_dtype, acc_dtype in (
      (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
      a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
      k = Kernel(create_schedule([a.lazydata])[-1].ast)
      k.linearize()
      local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
      assert local[0].dtype == acc_dtype
  def test_arg_acc_dtype(self):
    def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
      k = Kernel(create_schedule([c.lazydata])[-1].ast)
      k.linearize()
      local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
      assert local[0].dtype == expected_dtype
    tests = (
      (dtypes.float16, None, dtypes.float),
      (dtypes.bfloat16, None, dtypes.float),
      (dtypes.float, None, dtypes.float),
      (dtypes.float16, dtypes.float16, dtypes.float16),
      (dtypes.bfloat16, dtypes.bfloat16, dtypes.bfloat16),
      (dtypes.float, dtypes.float16, dtypes.float16),
    )
    for tensor_dtype, acc_dtype, expected_dtype in tests:
      a, b = Tensor.rand(8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, dtype=tensor_dtype)
      helper_arg_acc_dtype(a.sum(acc_dtype=acc_dtype), expected_dtype)
      helper_arg_acc_dtype(a.matmul(b, acc_dtype=acc_dtype), expected_dtype)
      helper_arg_acc_dtype(Tensor.einsum("ki,ij->kj", a, b, acc_dtype=acc_dtype), expected_dtype)
      d, w = Tensor.rand(4, 8, 8, 8, dtype=tensor_dtype), Tensor.rand(8, 8, 2, 2, dtype=tensor_dtype)
      helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
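  # tensor core tests: every TensorCore the renderer advertises should emit WMMA uops on an
  # exactly-sized matmul; padded or undersized shapes only trigger depending on the TC_OPT level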
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_cores(self):
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
      helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_cores_padded(self):
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
      pad = 1
      # check that TC is triggered for TC_OPT=2
      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad,
                                           tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=True)
      # check that TC is not triggered for TC_OPT<2
      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad,
                                           tc.dtype_in, tc.dtype_out, tc_opt=1, ensure_triggered=False)
      helper_tc_ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad,
                                           tc.dtype_in, tc.dtype_out, tc_opt=0, ensure_triggered=False)
      # check excessive padding doesn't trigger padded TC in TC_OPT=2
      helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
      helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
      helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
      # check correctness
      helper_tc_allclose(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_cores_multi_reduce(self):
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
      # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
      golden_result = None
      for axis in range(9):
        a = Tensor.rand(16, 16, 29, 29, dtype=tc.dtype_in).realize()
        b = Tensor.rand(32, 16, 16, 16, dtype=tc.dtype_in).realize()
        c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
        realized_ast, real_bufs = helper_realized_ast(c)
        k = Kernel(realized_ast)
        k.apply_tensor_cores(1, axis=axis, tc_opt=2)
        k.linearize()
        assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
        assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
        prg = CompiledRunner(k.to_program())
        real_bufs[0].copyin(np.zeros((real_bufs[0].size, ), dtype=_to_np_dtype(real_bufs[0].dtype)).data) # zero the output buffer to check that all values are filled
        prg.exec(real_bufs)
        result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype))
        # ensure the results for each choice of axis match
        if golden_result is None: golden_result = np.frombuffer(real_bufs[0].as_buffer(), _to_np_dtype(real_bufs[0].dtype))
        np.testing.assert_allclose(result, golden_result, atol=0.1, rtol=0.15)
      # check that get_kernel_actions produces all 9 options
      from tinygrad.engine.search import get_kernel_actions
      tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
      assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_cores_unroll_phi(self):
    tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
    x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
    r = x.matmul(y, acc_dtype=tc.dtype_out)
    k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
    for u in k.uops:
      if u.op is UOps.WMMA:
        assert u.src[-1].src[0].op != UOps.PHI
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_cores_unroll_casted_phi(self):
    tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
    x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
    r = x.matmul(y, acc_dtype=tc.dtype_out)
    k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
    for u in k.uops:
      if u.op is UOps.WMMA:
        assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
        assert u.src[-1].src[0].op != UOps.PHI
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_cores_unroll_casted_phi_with_children(self):
    # all PHI children are outside the loop
    tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
    x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
    r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
    k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
    for u in k.uops:
      if u.op is UOps.WMMA:
        assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
        assert u.src[-1].src[0].op != UOps.PHI
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_simple_unroll_no_between_phi_dependencies(self):
    x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
    r = (x@y).relu()
    k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
    # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE
    for u in k.uops:
      if u.op is UOps.PHI:
        assert u.src[1].op is UOps.ALU
      # children of PHI are placed after ENDRANGE
      if any(x.op is UOps.PHI for x in u.src):
        end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0]
        assert end_range < k.uops.uops.index(u)
  def test_grouped_dims(self):
    def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
      # TODO: fix reverse_dims
      idxs = get_grouped_dims(prefix, dims, max_sizes)
      loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
      sizes = [x.arg[2] for x in loop_idxs]
      assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
      assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
      assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
      # TODO: add these back after uop symbolic
      # for i in range(len(dims)):
      #   assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
      # for i in range(len(loop_idxs)):
      #   assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
      #   assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
    # no-op
    _assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
    _assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
    # check reverse dims
    # _assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
    _assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
    # test splitting globals
    # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
    # _assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,4,12])
    # _assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [12,16,4])
    # _assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,4,24])
    # collapse onto the left-most axis
    _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
    # _assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
    # _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
    # collapse onto the left-most available axis (the left-most is too small)
    # _assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
    # _assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
    _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
    # # dim too large and not factorable
    # with self.assertRaises(AssertionError):
    #   get_grouped_dims("gidx", 0, (23,), (16,16,16), False,)
    # with self.assertRaises(AssertionError):
    #   get_grouped_dims("gidx", 0, (128,3,4), (16,4,23), False,)
    # # too large for sizes
    # with self.assertRaises(AssertionError):
    #   get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16), False,)
    # # variable too large
    # with self.assertRaises(AssertionError):
    #   get_grouped_dims("gidx", 0, (Variable("start_pos",0,16),3,4), (16,16,16), False,)
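    # Worked example (commentary, not an extra assertion): with max_sizes=(16,16,16) there
    # are only three launch axes, so dims (2,3,4,5) collapse the two left-most dims into
    # one: 2*3 = 6 <= 16, giving launch sizes [6,4,5] while idxs still reconstruct all
    # four logical indices via div/mod on the merged axis.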
  def test_div_collapse(self):
    def helper(t, msg, max_ops=0):
      sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
      assert len(sched) == 1
      lin = Kernel(sched[0].ast)
      assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
    a = Tensor.rand((4,4))
    b = Tensor.rand((4,4))
    d = Tensor.rand((4,4))
    c = (a*b)/b
    helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
    c = a/a
    helper(c, "found UnaryOps.RECIP in (a/a) operation")
    c = (a/b)/d
    helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
  def test_sum_collapse(self):
    t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
    assert len(sched) == 1
    lin = Kernel(sched[0].ast)
    assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
  def test_assign_fold(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
    a.assign(a+m)
    a.realize()
    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
  def test_where_fold(self):
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.shrink(((1, 2), None)).pad(((1, 2), None))
    a.assign(b.where(2, a))
    sched = create_schedule([a.lazydata])
    assert len(sched) == 1
    sched_copy = sched[:]
    run_schedule(sched)
    np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
    lin = Kernel(sched_copy[-1].ast)
    lin.hand_coded_optimizations()
    lin.linearize()
    assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found a WHERE that should have been folded"
  def test_phi_simplification(self):
    def helper(t, max_ops=0):
      k = helper_linearizer_opt(t)[-1]
      uops = list(k.linearize().uops)
      # ignore kernel optimized IF statements for now
      if if_op:=next((u for u in uops if u.op is UOps.IF), None):
        uops = uops[:uops.index(if_op)]
      assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
      assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified"
      # TODO: once uops track min/max this will be fixed
      # assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
    helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2)
    helper(Tensor.arange(-1, -100, -5), max_ops=2)
    # NOTE: both of these split the reduce (this just wasn't tracked before)
    # helper(Tensor.arange(-3.2, 6.7, 0.64), max_ops=2)
    # helper(Tensor.arange(256), max_ops=2)
    helper(Tensor.arange(255), max_ops=2)
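    # Why this holds (commentary): arange lowers to a reduction whose accumulator is
    # linear in the loop counter, so the PHI/RANGE pair folds into a closed-form
    # expression like start + idx*step and no accumulator PHI survives linearization.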
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_grouped_store_phis(self):
    """
    float4 acc0 = float4(0.0,0.0,0.0,0.0);
    {
      acc0 = // ...
    }
    *((device float4*)(data0+alu2)) = float4(acc0.x,acc0.y,acc0.z,acc0.w);
    simplifies to:
    *((device float4*)(data0+alu2)) = acc0;
    """
    x, y = Tensor.randn(64,64), Tensor.randn(64,64)
    out = x.matmul(y)
    k = helper_linearizer_opt(out)[-1]
    # check that the float4 cast collapses
    store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
    for val in store_vals:
      assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.VECTORIZE
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_grouped_store_values(self):
    x = Tensor.randn((4,3,6,6)).realize()
    out = x.flip((0,1)).contiguous()
    k = helper_linearizer_opt(out)[-1]
    store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
    assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.VECTORIZE
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_grouped_store_locals_and_globals(self):
    x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
    out = x@y
    opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
           Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)]  # upcast accs in both reduces
    k = helper_linearizer_opt(out, opts=[opt])[-1]
    def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
    local_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
    global_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
    barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
    # check that the float4 cast collapses for all stores
    for store in local_stores+global_stores:
      assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.VECTORIZE
    # check the children's srcs
    # TODO: the src ALUs are not the same, should they be?
    # assert barrier.src == tuple(local_stores)
    assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_grouped_store_local_only(self):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    k = helper_linearizer_opt(r)[-1]
    stores = [u for u in k.uops if u.op is UOps.STORE]
    # the float4 value is stored directly in lds and the upcast is skipped
    assert stores[0].src[-1].dtype == dtypes.float.vec(4)
    assert stores[0].src[-1].op is not UOps.VECTORIZE
    # the global store doesn't change
    assert stores[1].src[2].dtype == dtypes.float
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_skip_unmatching_upcasts(self):
    Tensor.manual_seed(0)
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
    opt = [
      Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
      Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
    ]
    k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
    out = [u for u in k.uops if u.op is UOps.STORE][0]
    assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  @unittest.expectedFailure  # this will require compaction of BinaryOps.ADD
  def test_skip_unmatching_upcasts_with_gep(self):
    Tensor.manual_seed(0)
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
    opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
           Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
           Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
    k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
    out = [u for u in k.uops if u.op is UOps.STORE][0]
    assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
  @staticmethod
  def count_float4(k):
    return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
            len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(4)]))
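  # count_float4 returns (vectorized loads, vectorized stores): a LOAD counts when its
  # dtype is float4; a STORE counts when its value operand (src[2]) has dtype float4.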
  # TODO: express opts below as auto opts
  def test_float4_basic(self):
    a = Tensor.rand(2, 8).realize()
    b = Tensor.rand(2, 8).realize()
    c = a + b
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.hand_coded_optimizations()
    k.linearize()
    assert TestFloat4.count_float4(k) == (2, 1)
  def test_float4_multidim(self):
    a = Tensor.rand(2, 8).realize()
    b = Tensor.rand(2, 8).realize()
    c = a + b
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.shift_to(0, 4)  # float4 dimension
    k.shift_to(0, 2, insert_before=k.shape_len-1)
    k.upcast()
    k.upcast()
    k.local_dims += 1
    k.linearize()
    assert TestFloat4.count_float4(k) == (4, 2)
  def test_float4_unaligned_load(self):
    a = Tensor.rand(9).realize().shrink(((1, 9),))
    b = Tensor.rand(9).realize().shrink(((1, 9),))
    c = a + b
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.hand_coded_optimizations()  # implicitly triggers the float4 dim
    k.linearize()
    assert TestFloat4.count_float4(k) == (0, 1)
  def test_float4_multidim_unaligned_load(self):
    a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
    b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
    c = a + b
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.shift_to(len(k.full_unupcasted_shape)-1, 4)  # manually trigger the float4 dim
    k.upcast()
    k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
    k.upcast()
    k.local_dims += 1
    k.linearize()
    assert TestFloat4.count_float4(k) == (0, 2)
  def test_float4_sometimes_unaligned(self):
    a = Tensor.rand(1, 1, 8).realize()
    b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
    c = a.conv2d(b)
    # only the first and last conv dot products are aligned in a, and b is never aligned, so no
    # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.upcast()
    k.linearize()
    assert TestFloat4.count_float4(k) == (0, 0)
  def test_float4_multidim_sometimes_unaligned(self):
    a = Tensor.rand(1, 1, 7).realize()
    b = Tensor.rand(1, 1, 5).realize().shrink(((0, 1), (0, 1), (1, 5)))
    c = a.conv2d(b)
    # the first conv dot product is aligned in a. If we upcast the output and reduce
    # dimension, then we could do float4 for only that one set of loads, but we currently
    # don't.
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.upcast()
    k.upcast()
    k.linearize()
    assert TestFloat4.count_float4(k) == (0, 1)
  def test_float4_noncontiguous(self):
    a = Tensor.rand(4, 2).realize()
    b = Tensor.rand(4, 2).realize()
    c = a + b
    # we will upcast the top axis of size 4; the loads should not be coalesced into float4,
    # since the top axis is not contiguous.
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.shift_to(0, 4, top=True)  # top axes are float4 axes
    k.upcast()
    k.linearize()
    assert TestFloat4.count_float4(k) == (0, 0)
  def test_float4_expand(self):
    a = Tensor.rand(9).realize().shrink(((1, 9),))
    b = Tensor.rand(2).realize().reshape((2, 1)).expand((2,4)).reshape((8,))
    c = a + b
    # we will upcast an axis of size 4; the loads should not be coalesced into float4
    # (a is an unaligned shrink and b is an expand).
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.shift_to(0, 4)  # float4 axis
    k.upcast()
    k.linearize()
    assert TestFloat4.count_float4(k) == (0, 1)
  def test_float4_heterogeneous(self):
    a = Tensor.rand(8).realize()
    b = Tensor.rand(9).realize().shrink(((1, 9),))
    c = a + b
    # should float4 the aligned a, but not the unaligned b
    s = create_schedule([c.lazydata])[0]
    k = Kernel(s.ast)
    k.shift_to(0, 4)  # float4 axis
    k.upcast()
    k.linearize()
    assert TestFloat4.count_float4(k) == (1, 1)
class TestHandCodedOpts(unittest.TestCase):
  def test_masked_upcast(self):
    layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
    layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
    s = create_schedule([layer_2.lazydata])[-1]
    k = Kernel(s.ast)
    k.hand_coded_optimizations()
    assert len(k.bufs) == 6  # make sure all ops are done in one kernel
    # masked upcast should upcast masked axis of size 7
    # masked upcast should not upcast large (20) last axis
    # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous
    assert k.upcasted == 1 and k.full_shape[-1] == 7
  def test_masked_upcast_wino(self):
    monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
    s = create_schedule([monster.lazydata])[-1]
    k = Kernel(s.ast)
    k.hand_coded_optimizations()
    assert len(k.bufs) == 37  # make sure all ops are done in one kernel
    # should upcast the two Tensor.stacks
    assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
  def test_masked_upcast_wino_full(self):
    with Context(WINO=1):
      x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
      out = Tensor.conv2d(x,w, padding=1)
      upcasts = []
      wino_schedule = create_schedule([out.lazydata])
      # collect upcasts of tile transform kernels
      for i, si in enumerate(wino_schedule):
        k = Kernel(si.ast)
        k.hand_coded_optimizations()
        if k.reduceop is not None: continue  # not a tile transform kernel (there is a gemm reduce kernel)
        if len(k.bufs) < 36: continue  # not a tile transform kernel (there's a permute kernel at the end)
        upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
      assert len(upcasts) == 3  # 3 transformation matrices
      assert len(wino_schedule) <= 4  # 4 kernels
      # this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
      assert upcasts.count((6, 6)) == 2  # and upcasts.count((4, 4)) == 1
      out.mean().backward()
      backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
      for si in backward_schedule:
        k = Kernel(si.ast)
        k.hand_coded_optimizations()
        k.linearize()
        if len(k.bufs) < 20: continue  # not a tile transform kernel
        # heuristic number to make sure that at least some upcasts but not too many upcasts are being done
        assert 6 <= prod(k.full_shape[k.shape_len - k.upcasted:k.shape_len]) <= 216
      assert len(backward_schedule) <= 13  # just the current number, but it could be better
  def test_masked_upcast_many(self):
    layer_1 = Tensor.cat(Tensor.rand(3, 4), Tensor.rand(4, 4))
    layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4))
    layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))
    k = helper_linearizer_opt(layer_3)[-1]
    assert len(k.bufs) == 5  # make sure all ops are done in one kernel
    # check that we don't do too many upcasts
    assert prod(k.full_shape[k.shape_len-k.upcasted:k.shape_len]) <= 49
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  def test_matvec(self):
    N = 128
    a = Tensor.rand(1, N).realize()
    b = Tensor.rand(N, N).realize()
    c = a @ b
    k = helper_linearizer_opt(c)[-1]
    assert k.group_for_reduces == 1
    assert k.local_dims == 1
    assert k.upcasted == 1
def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp], inputs:List[Tensor], *args, **kwargs):
  if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.KERNEL, ast)
  inbufs = [x.lazydata.buffer for x in inputs]
  outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
  return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
  realized_ast, real_bufs = helper_realized_ast(r)
  return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts=[],
                               apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]:
  lins: List[Kernel] = []
  outbufs = [(real_bufs[i], lop.arg.st.shape) for i,lop in enumerate(realized_ast.src)]
  def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
  def check_opt(opts, create_k, expected_color_size):
    k = create_k()
    lins.append(k)
    if apply_tc:
      assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
    else:
      for opt in opts:
        k.apply_opt(opt)
    if expected_color_size is not None:
      assert (cs:=list(zip(k.colors(), k.full_shape))) == expected_color_size, f"expected={expected_color_size} got={cs}"
    prg = get_prg(k)
    for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data)  # zero the buffers to check that all values are filled
    prg.exec(real_bufs)
    for i, (buf,shape) in enumerate(outbufs):
      np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol)
  # get the unoptimized baseline output if one was not provided
  k = Kernel(realized_ast)
  lins.append(k)
  prg = get_prg(k)
  prg.exec(real_bufs)
  if len(wanna_output) == 0: wanna_output = [np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape).copy() for buf,shape in outbufs]
  else:
    for i, (buf,shape) in enumerate(outbufs):
      np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol)
  # check correctness of hand-coded optimizations
  k = Kernel(realized_ast)
  lins.append(k)
  k.hand_coded_optimizations()
  prg = get_prg(k)
  for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data)  # zero the buffers to check that all values are filled
  prg.exec(real_bufs)
  for i, (buf,shape) in enumerate(outbufs):
    np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)).reshape(shape), wanna_output[i], atol=atol, rtol=rtol)
  for i, x in enumerate(opts):  # check custom transformations if any
    check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
  return lins
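# Minimal usage sketch (illustrative; mirrors how the tests below call these helpers).
# The last Kernel in the returned list corresponds to the last opts list applied:
#   r = Tensor.rand(8, 8) @ Tensor.rand(8, 8)
#   lins = helper_linearizer_opt(r, opts=[[Opt(OptOps.UPCAST, 0, 4)]])
#   k = lins[-1]  # Kernel with UPCAST applied, already verified against the baseline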
# creates a back-to-back multi reduce AST by merging r0 and r1.
# TODO: delete once we can schedule multi reduce
def _temp_create_multireduce_ast(r0:Tensor, r1:Tensor, replace_idxs:Dict[int,Tensor]={}, \
                                 merge=lambda r0,r1: LazyOp(BinaryOps.ADD, (r0, r1))) -> Tuple[LazyOp, ...]:
  assert len(s0:=r0.schedule()) == 1 and len(s1:=r1.schedule()) == 1, "inputs should be realized"
  assert all({idx:replace_idxs[idx] is r0 or replace_idxs[idx] is r1 for idx in replace_idxs}.values()), "replace idxs should be in {{r0, r1}}"
  op0, op1 = s0[0].ast.src[0].src[0], s1[0].ast.src[0].src[0]
  _replace_idxs = {idx:(op0 if replace_idxs[idx] is r0 else op1) for idx in replace_idxs}
  def _deep_replace(op:LazyOp, offset=0):
    if op.op is BufferOps.LOAD:
      if op.arg.idx+offset in _replace_idxs: return _replace_idxs[op.arg.idx+offset]
      else: arg = MemBuffer(op.arg.idx+offset, op.arg.dtype, op.arg.st)
    else: arg = op.arg
    return LazyOp(op.op, tuple(_deep_replace(x, offset) for x in op.src), arg)
  # limitation: r0 and r1 cannot share inputs.
  op0 = _deep_replace(op0, 0)
  op0_loads = len([x for x in op0.lazyops if x.op is BufferOps.LOAD])
  out = merge(op0, _deep_replace(op1, op0_loads))
  # limitation: only tests single output
  op = LazyOp(BufferOps.STORE, (out, ), MemBuffer(0, s0[-1].ast.src[-1].arg.dtype, s0[-1].ast.src[-1].arg.st))
  if DEBUG >= 3: print_tree(op)
  return op,
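# Usage sketch (illustrative; see the multireduce tests below for real invocations):
#   a, b = Tensor.rand(4, 32).realize(), Tensor.rand(4, 32).realize()
#   r0, r1 = a.sum(-1), b.sum(-1)
#   ast = _temp_create_multireduce_ast(r0, r1)  # default merge adds the two reduces
#   helper_linearizer_ast(ast, [a, b], [[Opt(OptOps.GROUP, 0, 2)]])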
def check_fused_tc_opt(tc:TensorCore, r0:Tensor, r1:Tensor, inputs:List[Tensor]):
  ast = _temp_create_multireduce_ast(r0, r1)
  (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
  helper_linearizer_ast(ast, inputs, [
    [],
    [Opt(OptOps.UPCAST, 0, 4)],
    [Opt(OptOps.UPCAST, 1, 4)],
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)],  # check upcasts
    [Opt(OptOps.UNROLL, 0, 2)],  # check unroll
    [Opt(OptOps.UNROLL, 0, 0)],  # check full unroll of reduce with locals
    [Opt(OptOps.LOCAL, 0, 4)],  # check local
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)],  # check combo of unroll and upcast
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
    [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)],  # check permutations
    [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
    [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
    [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
    [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
    # [Opt(OptOps.GROUP, 0, 2)]  # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
  ], apply_tc=True, atol=atol, rtol=rtol)
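# Note (commentary): check_fused_tc_opt scales the tolerances by dtype because half
# inputs and half accumulators lose precision in the fused reduce; the opts list mirrors
# test_tensor_core_opts so fused and unfused tensor core paths get the same coverage.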
class TestKernelOpts(unittest.TestCase):
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_local_and_grouped_reduce(self):
    N = 128
    Tensor.manual_seed(1882)
    a = Tensor.rand(4, 4, N, N)
    b = Tensor.rand(4, 4, N)
    r = (b.sqrt() + ((a+1).sum(axis=3).exp()))
    helper_linearizer_opt(r, [
      [Opt(OptOps.LOCAL, 0, 2)],
      [Opt(OptOps.LOCAL, 0, 8)],
      [Opt(OptOps.LOCAL, 0, 16)],  # Checking how it works with locals
      [Opt(OptOps.GROUPTOP, 0, 2)],
      [Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.GROUPTOP, 0, 64)],  # Checking how it works with grouped reduce
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)],
      [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)],
      [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)],
      # Checking how it works with locals + grouped reduce
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)],
      # Checking how it works with locals + grouped reduce + upcasts
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)],
      # many local + many group
      [Opt(OptOps.GROUP, 0, 2)] * 4,
      [Opt(OptOps.LOCAL, 0, 2)] * 4,
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUP, 0, 2)] * 4,
    ])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skip("parallel reduce")
  def test_local_and_grouped_reduce_multireduce(self):
    N = 128
    Tensor.manual_seed(1882)
    a = Tensor.rand(4, 4, N, N).realize()
    b = Tensor.rand(4, 4, N).realize()
    # TODO: this isn't the best AST, it's always math.inf
    r0 = (b.sqrt() + ((a+1).sum(axis=3).exp()))
    c = Tensor.rand(4, 4, N, N).realize()
    d = Tensor.rand(4, 4, N).realize()
    r1 = (d.sqrt() + ((c+1).sum(axis=3).exp()))
    ast = _temp_create_multireduce_ast(r0, r1)
    helper_linearizer_ast(ast, [b, a, d, c], [
      [Opt(OptOps.LOCAL, 0, 2)],
      [Opt(OptOps.LOCAL, 0, 8)],
      [Opt(OptOps.LOCAL, 0, 16)],  # Checking how it works with locals
      [Opt(OptOps.GROUPTOP, 0, 2)],
      [Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.GROUPTOP, 0, 64)],  # Checking how it works with grouped reduce
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2)],
      [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.GROUPTOP, 0, 16)],
      [Opt(OptOps.LOCAL, 0, 32), Opt(OptOps.GROUPTOP, 0, 2)],
      # Checking how it works with locals + grouped reduce
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 64)],
      # Checking how it works with locals + grouped reduce + upcasts
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.UNROLL, 1, 4)],
    ])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skip("AST has implicit movement ops")
  def test_atomic_store_multireduce(self):
    # reduceops will need to use the local buffer to load the result of a local reduce into every thread; barriers are needed on both sides
    # of the load to ensure 1) the correct value is in the local buffer and 2) the value isn't overwritten by the next reduceop
    N = 512
    Tensor.manual_seed(1882)
    a,b = Tensor.rand(4,4,N).realize(), Tensor.rand(4,4,N).realize()
    r0,r1 = a.sum(-1), b.sum(-1)
    ast = _temp_create_multireduce_ast(r0, r1)
    lins = helper_linearizer_ast(ast, [a,b], [[Opt(OptOps.GROUP, 0, 2)]])
    # sequential
    a,b = Tensor.rand(4,4,N).realize(), Tensor.rand(4,4,N).realize()
    dummy = Tensor.rand(4,4,1).realize()
    r0,r1 = (a-dummy).sum(-1), b.sum(-1)
    ast = _temp_create_multireduce_ast(r0, r1, replace_idxs={2:r1}, merge=lambda r0,_: r0)
    lins += helper_linearizer_ast(ast, [a], [[Opt(OptOps.GROUP, 0, 2)]])
    for k in lins:
      seen_bar = False
      for u in k.uops:
        if u.op is UOps.BARRIER:
          assert not seen_bar, "redundant barrier"
          seen_bar = True
        elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False
  @unittest.skip("TODO: broken")
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_atomic_store_unrolled_multireduce(self):
    # unrolled local dim - causes stores for local reductions to pool at the top of the kernel, overwriting each other
    Tensor.manual_seed(1882)
    a,b = Tensor.rand(4,).realize(), Tensor.rand(4,).realize()
    r0,r1 = a.sum(), b.sum()
    ast = _temp_create_multireduce_ast(r0, r1)
    lins = helper_linearizer_ast(ast, [a,b], [
      [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.GROUP, 0, 2)]
    ])
    for k in lins:
      seen_bar = False
      for u in k.uops:
        if u.op is UOps.BARRIER:
          assert not seen_bar, "redundant barrier"
          seen_bar = True
        elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skip("AST has implicit movement ops")
  def test_atomic_store_nested_range_multireduce(self):
    # nested ranges
    Tensor.manual_seed(1882)
    a,b = Tensor.rand(6, ).realize(), Tensor.rand(6, ).realize()
    r0,r1 = a.reshape(6, 1).expand(6, 3).sum(), b.reshape(6, 1).expand(6, 3).sum()
    ast = _temp_create_multireduce_ast(r0, r1)
    lins = helper_linearizer_ast(ast, [a,b], [
      [Opt(OptOps.GROUP, 0, 2)],[Opt(OptOps.GROUP, 1, 3)],
      [Opt(OptOps.GROUP, 1, 3), Opt(OptOps.GROUP, 0, 2)],
      [Opt(OptOps.UNROLL, 0, 2)],[Opt(OptOps.UNROLL, 1, 3)],
      [Opt(OptOps.GROUP, 0, 2), Opt(OptOps.UNROLL, 0, 2)],
      [Opt(OptOps.GROUP, 1, 3), Opt(OptOps.UNROLL, 1, 3)],
    ])
    for k in lins:
      seen_bar = False
      for u in k.uops:
        if u.op is UOps.BARRIER:
          assert not seen_bar, "redundant barrier"
          seen_bar = True
        elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False
  def test_upcasts(self):
    N = 16
    Tensor.manual_seed(1772)
    a = Tensor.rand(N, N)
    b = Tensor.rand(N, N)
    r = (a+b).sqrt() * ((a+1).exp())
    helper_linearizer_opt(r, [
      [Opt(OptOps.UPCAST, 0, 2)],
      [Opt(OptOps.UPCAST, 0, 4)],
      [Opt(OptOps.UPCAST, 0, 8)],  # Checking how it works with upcasts
    ])
  def test_full_upcast(self):
    Tensor.manual_seed(1772)
    a = Tensor.rand(4)
    b = Tensor.rand(4)
    r = (a+b).sqrt() * ((a+1).exp())
    helper_linearizer_opt(r, [
      [Opt(OptOps.UPCAST, 0, 4)],  # Checking how it works with upcasts
    ])
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_matmul(self):
    N = 128
    Tensor.manual_seed(1552)
    a = Tensor.rand(N, N)
    b = Tensor.rand(N, N)
    r = a@b
    helper_linearizer_opt(r, [
      [Opt(OptOps.UPCAST, 0, 2)],
      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)],  # Checking how it works with upcasts
      [Opt(OptOps.LOCAL, 0, 2)],
      [Opt(OptOps.LOCAL, 1, 32)],
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)],
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)],
      [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)],  # Checking how it works with locals
      [Opt(OptOps.GROUPTOP, 0, 2)],
      [Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)],  # Checking how it works with grouped_reduce
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)],  # Checking how it works with local+grouped_reduce
      # Checking all together
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4),
       Opt(OptOps.UPCAST, 1, 2)],
      # Full global upcast + local
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)],
    ])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skip("AST has implicit movement ops")
  def test_matmul_multireduce(self):
    N = 128
    Tensor.manual_seed(1552)
    a = Tensor.rand(N, N).realize()
    b = Tensor.rand(N, N).realize()
    r0 = a@b
    c = Tensor.rand(N, N).realize()
    d = Tensor.rand(N, N).realize()
    r1 = c@d
    ast = _temp_create_multireduce_ast(r0, r1)
    helper_linearizer_ast(ast, [a, b, c, d], [
      [Opt(OptOps.UPCAST, 0, 2)],
      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)],  # Checking how it works with upcasts
      [Opt(OptOps.LOCAL, 0, 2)],
      [Opt(OptOps.LOCAL, 1, 32)],
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)],
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 32)],
      [Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.LOCAL, 1, 8)],  # Checking how it works with locals
      [Opt(OptOps.GROUPTOP, 0, 2)],
      [Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.GROUPTOP, 0, 32), Opt(OptOps.UNROLL, 0, 4)],  # Checking how it works with grouped_reduce
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 8), Opt(OptOps.GROUPTOP, 0, 4)],  # Checking how it works with local+grouped_reduce
      # Checking all together
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4),
       Opt(OptOps.UPCAST, 1, 2)],
      # Full global upcast + local
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 8)],
    ], wanna_output=[(a.numpy()@b.numpy()+c.numpy()@d.numpy()).flatten()])
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_double_reduce(self):
    N = 128
    Tensor.manual_seed(1552)
    a = Tensor.rand(8, N, 8, N)
    r = a.sum(axis=(1,3))
    helper_linearizer_opt(r, [
      # openCL / GPU=1 is 256 max threads
      [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)],  # Checking how it works with 1 grouped_reduce.
      [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
      [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)],
      [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)],  # Checking how it works with 2 grouped_reduces.
      [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)],
      [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)],  # Checking how it works with 2 grouped_reduces + upcasts.
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)],
      # Checking how it works with 2 grouped_reduces + upcasts + locals.
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)],
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)],
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
       Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)],  # Checking how it works with 2 grouped_reduces + upcasts + locals.
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
       Opt(OptOps.UPCAST, 0, 2)],  # No globals
    ])
  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  @unittest.skip("AST has implicit movement ops")
  def test_double_reduce_multireduce(self):
    N = 128
    Tensor.manual_seed(1552)
    a = Tensor.rand(8, N, 8, N).realize()
    r0 = a.sum(axis=(1,3))
    b = Tensor.rand(8, N, 8, N).realize()
    r1 = b.sum(axis=(1,3))
    ast = _temp_create_multireduce_ast(r0, r1)
    helper_linearizer_ast(ast, [a, b], [
      # openCL / GPU=1 is 256 max threads
      [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
      [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)],  # Checking how it works with 1 grouped_reduce.
      [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],
      [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2)],
      [Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 64)],  # Checking how it works with 2 grouped_reduces.
      [Opt(OptOps.GROUPTOP, 0, 16), Opt(OptOps.GROUPTOP, 1, 2), Opt(OptOps.UNROLL, 0, 4)],
      [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 2, 4)],  # Checking how it works with 2 grouped_reduces + upcasts.
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4)],
      # Checking how it works with 2 grouped_reduces + upcasts + locals.
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 32), Opt(OptOps.UNROLL, 1, 4)],
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2)],
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 1, 2), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
       Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UNROLL, 1, 4)],  # Checking how it works with 2 grouped_reduces + upcasts + locals.
      [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4), Opt(OptOps.GROUPTOP, 0, 4), Opt(OptOps.GROUPTOP, 1, 4), Opt(OptOps.UPCAST, 0, 2),
       Opt(OptOps.UPCAST, 0, 2)],  # No globals
    ], wanna_output=[(a.numpy().sum(axis=(1, 3))+b.numpy().sum(axis=(1, 3))).flatten()])
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_invalid_tensor_core_extra_opts(self):
    N = 128
    Tensor.manual_seed(1552)
    a = Tensor.rand(N, N)
    b = Tensor.rand(N, N)
    realized_ast, _ = helper_realized_ast(a@b)
    invalid_opts = [
      [Opt(OptOps.LOCAL, 2, 2)],
      [Opt(OptOps.UPCAST, 2, 2)],
      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
    ]
    for x in invalid_opts:
      k = Kernel(realized_ast)
      with self.assertRaises(AssertionError):
        assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core"  # for METAL in runners
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_buf_index_not_found_tensor_core(self):
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
    k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
    with self.assertRaises(KernelOptError):
      k.apply_opt(Opt(OptOps.TC, 0, 1))
  @unittest.skip("parallel tensor cores")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_invalid_fused_tensor_core(self):
    Tensor.manual_seed(1552)
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if tc.dtype_in == dtypes.bfloat16: continue
      M, N, K = 12, 8, 30
      a, b = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
      r0 = a.matmul(b, acc_dtype=tc.dtype_out)
      M, N, K = 16, 8, 33
      c, d = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
      r1 = c.matmul(d, acc_dtype=tc.dtype_out)
      ast = _temp_create_multireduce_ast(r0, r1)
      lin = Kernel(ast)
      lin.apply_opt(Opt(op=OptOps.TC, axis=0, amt=2))
      lin.linearize()
      result = compare_linearizer(lin)
      assert result[0] == "COMPARE_ERROR"
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_tensor_core_opts(self):
    N = 128
    Tensor.manual_seed(1552)
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      # bf16 buffer returns float32 numpy outputs so the test would fail. testing the opt with half suffices.
      if tc.dtype_in == dtypes.bfloat16: continue
      a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
      r = a.matmul(b, acc_dtype=tc.dtype_out)
      (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
      helper_linearizer_opt(r, [
        [],
        [Opt(OptOps.UPCAST, 0, 4)],
        [Opt(OptOps.UPCAST, 1, 4)],
        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)],  # check upcasts
        [Opt(OptOps.UNROLL, 0, 2)],  # check unroll
        [Opt(OptOps.UNROLL, 0, 0)],  # check full unroll of reduce with locals
        [Opt(OptOps.LOCAL, 0, 4)],  # check local
        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)],  # check combo of unroll and upcast
        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
        [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)],  # check permutations
        [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
        [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
        [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
        # [Opt(OptOps.GROUP, 0, 2)]  # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
      ], apply_tc=True, atol=atol, rtol=rtol)
  @unittest.skip("parallel tensor cores")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_fused_tensor_core_simple(self):
    N = 64
    Tensor.manual_seed(1552)
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if tc.dtype_in == dtypes.bfloat16: continue
      [a, b, c, d] = [Tensor.randn(N, N, dtype=tc.dtype_in).realize() for _ in range(4)]
      r0 = a.matmul(b, acc_dtype=tc.dtype_out)
      r1 = c.matmul(d, acc_dtype=tc.dtype_out)
      check_fused_tc_opt(tc, r0, r1, [a, b, c, d])
  @unittest.skip("parallel tensor cores")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  def test_fused_tensor_core_permuted(self):
    N = 64
    Tensor.manual_seed(1552)
    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
      if tc.dtype_in == dtypes.bfloat16: continue
      # one permuted
      [a, b, c, d] = [Tensor.randn(N, N, dtype=tc.dtype_in).realize() for _ in range(4)]
      r0 = a.matmul(b, acc_dtype=tc.dtype_out)
      r1 = c.T.matmul(d, acc_dtype=tc.dtype_out)
      check_fused_tc_opt(tc, r0, r1, [a, b, c, d])
      # both permuted
      r0 = a.T.matmul(b, acc_dtype=tc.dtype_out)
      r1 = c.T.matmul(d, acc_dtype=tc.dtype_out)
      check_fused_tc_opt(tc, r0, r1, [a, b, c, d])
  def test_padto_matmul(self):
    if CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]: self.skipTest("super slow on CUDA and AMD because of the big grid dims")
    N = 17 * 17
    Tensor.manual_seed(289)
    a = Tensor.rand(N, N)
    b = Tensor.rand(N, N)
    helper_linearizer_opt(a@b, [
      [Opt(OptOps.PADTO, 0, 32)],
      [Opt(OptOps.PADTO, 1, 32)],
      [Opt(OptOps.PADTO, 2, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
      # can optimize further post PADTO
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
    ])
  def test_padto_upcasted_not_ok(self):
    N = 4
    a = Tensor.rand(N, N)
    b = Tensor.rand(N, N)
    helper_linearizer_opt(a@b, [
      [Opt(OptOps.UPCAST, 0, 0)],
      [Opt(OptOps.UPCAST, 1, 0)],
      [Opt(OptOps.UNROLL, 0, 0)],
      [Opt(OptOps.PADTO, 0, 8)],
      [Opt(OptOps.PADTO, 1, 8)],
      [Opt(OptOps.PADTO, 2, 8)],
    ])
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 2, 8)]])
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
  1503. def test_padto_sum_ok(self):
  1504. N = 18 * 18
  1505. # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
  1506. a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
  1507. b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17)))
  1508. helper_linearizer_opt(a.sum(0), [
  1509. [Opt(OptOps.PADTO, 0, 32)],
  1510. [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
  1511. ])
  1512. helper_linearizer_opt(a.sum(1), [
  1513. [Opt(OptOps.PADTO, 0, 32)],
  1514. [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
  1515. ])
  1516. # can pad sum reduce axis if there's no unsafe ops prior to sum
  1517. for axis in (0, 1):
  1518. helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
  1519. helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
  1520. helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
  1521. helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
  1522. helper_linearizer_opt(b.sum(acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
  1523. helper_linearizer_opt(b.sum(0, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
  1524. helper_linearizer_opt(b.sum(1, acc_dtype=dtypes.bool), [[Opt(OptOps.PADTO, axis, 32)],])
  1525. # having unsafe ops after sum is fine
  1526. helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
  1527. helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])

  def test_padto_sum_not_ok(self):
    N = 18 * 18
    # NOTE: this setup prevents the 17 * 17 shrunk view from being merged into one contiguous dimension
    a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
    # exp before the sum is not safe to pad: the masked zeros would become exp(0) == 1
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
    b = a < -1
    # lt before the sum is not safe to pad either: a comparison can map the masked zeros to True
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
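
  # Numpy sketch of the failure mode being guarded against: the mask inserts zeros *before*
  # the unsafe op runs, and exp(0) == 1 leaks into the total.
  #   import numpy as np
  #   x = np.random.rand(17)
  #   assert not np.isclose(np.exp(np.pad(x, (0, 15))).sum(), np.exp(x).sum())  # off by 15.0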

  def test_padto_max(self):
    N = 18 * 18
    # NOTE: this setup prevents the 17 * 17 shrunk view from being merged into one contiguous axis
    a = -Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
    helper_linearizer_opt(a.max(0), [
      [Opt(OptOps.PADTO, 0, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
    ])
    helper_linearizer_opt(a.max(1), [
      [Opt(OptOps.PADTO, 0, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
    ])
    # cannot pad the reduce axis of a max kernel: 0 is not an identity for max
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
    with self.assertRaises(KernelOptError):
      helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
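
  # Numpy sketch: 0 is not an identity for max, so a masked pad can win the reduce.
  #   import numpy as np
  #   x = -np.random.rand(17) * 100           # negative, like `a` above
  #   assert np.pad(x, (0, 15)).max() == 0.0  # the pad value, not x.max()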

  def test_padto_where(self):
    Tensor.manual_seed(0)
    N = 17 * 17
    a = (Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1).where(1, 0)
    helper_linearizer_opt(a.max(0), [
      [Opt(OptOps.PADTO, 0, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
    ])

  def test_padto_where_multioutput(self):
    Tensor.manual_seed(0)
    N = 17 * 17
    r = Tensor.randn(N, N).realize().max(axis=0, keepdim=True) > 1
    a0 = r.where(1, 0)
    a1 = r.where(2, 0)
    helper_linearizer_opt([a0.max(0), a1.max(0)], [
      [Opt(OptOps.PADTO, 0, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
    ])
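
  # The two tests above exercise PADTO around TernaryOps.WHERE; the multioutput variant
  # additionally checks that both stores agree on the padded shape and its mask.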

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_padto_group(self):
    Tensor.manual_seed(0)
    ld0 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)))) # noqa: E501
    ld1 = LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))) # noqa: E501
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(ld0, ld1)),), arg=(0, 2, 4, 6)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
    data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
    data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
    helper_linearizer_ast((ast, ), [data1, data2], opts=[
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
    ])
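
  # GROUP accumulates the reduce through shared memory across local threads (hence the
  # has_local/has_shared guards); the hand-written AST checks that PADTO composes with it.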

  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_padto_sum_multireduce(self):
    Tensor.manual_seed(0)
    N = 17
    x = Tensor.rand(N, N).realize()
    opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]
    x_ld = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((N, N))))
    def ast(axis, output_shape):
      r0 = LazyOp(ReduceOps.SUM, (x_ld,), axis)
      r1 = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.ADD, (x_ld, LazyOp(op=UnaryOps.NEG, src=(r0,), arg=None)),),), axis)
      return LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape(output_shape))),
    helper_linearizer_ast(ast((0, ), (1, 17)), [x], opts=opts, wanna_output=[(x.numpy()-x.numpy().sum(axis=0,keepdims=True)).sum(0)])
    helper_linearizer_ast(ast((1, ), (17, 1)), [x], opts=opts, wanna_output=[(x.numpy()-x.numpy().sum(axis=1,keepdims=True)).sum(1)])
    expected = (x.numpy()-x.numpy().sum(axis=0,keepdims=True)).sum(0)
    helper_linearizer_ast(ast((0, ), (1, 17)), [x], opts=[[Opt(OptOps.PADTO, 1, 32)]], wanna_output=[expected])
    op = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(x_ld,LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(x_ld,), arg=(0,1)),),arg=None))),), arg=(0,1)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
    helper_linearizer_ast((op,), [x], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[(x.numpy()-x.numpy().sum(keepdims=True)).sum()])
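
  # The AST above hand-builds r1 = sum(x - sum(x)) -- a mean-subtraction-style double reduce
  # the scheduler cannot yet emit as one kernel, which is why it is constructed directly
  # (and currently skipped while such ASTs carry implicit movement ops).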

  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_padto_max_multireduce(self):
    Tensor.manual_seed(0)
    N = 17
    x = Tensor.rand(N, N).realize()
    opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]
    x_ld = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, ShapeTracker.from_shape((N, N))))
    def ast(axis, output_shape):
      r0 = LazyOp(ReduceOps.MAX, (x_ld,), axis)
      r1 = LazyOp(ReduceOps.MAX, (LazyOp(BinaryOps.ADD, (x_ld,r0,),),), axis)
      return LazyOp(BufferOps.STORE, (r1, ), MemBuffer(0, dtypes.float, ShapeTracker.from_shape(output_shape))),
    helper_linearizer_ast(ast((0, ), (1, 17)), [x], opts=opts, wanna_output=[(x.numpy()+x.numpy().max(axis=0,keepdims=True)).max(0)])
    helper_linearizer_ast(ast((1, ), (17, 1)), [x], opts=opts, wanna_output=[(x.numpy()+x.numpy().max(axis=1,keepdims=True)).max(1)])

  @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
  @unittest.skip("AST has implicit movement ops")
  def test_padto_where_multireduce(self):
    # we need to make sure the ternary operators nest properly
    N = 17
    x = Tensor.rand(N, N).realize()
    a = Tensor.rand(1, 1).realize()
    b = Tensor.rand(1, 1).realize()
    opts = [[Opt(OptOps.PADTO, 0, 32)],[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],]
    # TODO: these large ASTs are suboptimal, but we need them until the scheduler can fuse these
    wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0)
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),), arg=(1,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),)),), arg=(1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((N,1)))) # noqa: E501
    helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output])
    wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0)
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),), arg=(0,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),)),), arg=(0,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,N)))) # noqa: E501
    helper_linearizer_ast((ast,), [x,a,b], opts=opts, wanna_output=[wanna_output])
    # pad the reduce axis
    helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 1, 32)],], wanna_output=[wanna_output])
    wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0)
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.5*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),LazyOp(op=TernaryOps.WHERE, src=(LazyOp(op=BinaryOps.CMPLT, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.75*17, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker.from_shape((N,N)))),), arg=(0,1,)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),)),), arg=(0,1,)),)),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))),)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker.from_shape((1,1)))) # noqa: E501
    helper_linearizer_ast((ast,), [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output.flatten()])

  def test_padto_matmul_multireduce(self):
    if CI and Device.DEFAULT in ["AMD", "NV", "CUDA"]: self.skipTest("super slow on NV/CUDA and AMD because of the big grid dims")
    N = 17 * 17
    Tensor.manual_seed(289)
    a = Tensor.rand(N, N).realize()
    b = Tensor.rand(N, N).realize()
    c = Tensor.rand(N, N).realize()
    d = Tensor.rand(N, N).realize()
    r0 = a@b
    r1 = c@d
    ast = _temp_create_multireduce_ast(r0,r1)
    helper_linearizer_ast(ast, [a,b,c,d], opts=[
      [Opt(OptOps.PADTO, 0, 32)],
      [Opt(OptOps.PADTO, 1, 32)],
      [Opt(OptOps.PADTO, 2, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32)],
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)],
      # can optimize further post-PADTO
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
    ], wanna_output=[(a.numpy()@b.numpy()+c.numpy()@d.numpy()).reshape(N, N, 1)])
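
  # _temp_create_multireduce_ast fuses the two products into a single two-reduce AST; the
  # wanna_output above pins it to a@b + c@d, so every PADTO variant must reproduce that.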

  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
  def test_color_shapes_with_local(self):
    N = 32
    Tensor.manual_seed(1552)
    a = Tensor.rand(N, N)
    b = Tensor.rand(N, N)
    r = a@b
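    # Rough color legend (from Kernel.colors at the time of writing; treat as an assumption):
    # blue = global dims, cyan = local dims, green = grouped reduce, red = reduce loop,
    # magenta = unrolled (reduce) upcast, yellow = normal upcast.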
    opts_shapes = [
      ([Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("red",32)]),
      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",2),("red",16)]),
      # check that local_dims stay stable under a full UNROLL of first_reduce
      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
      ([Opt(OptOps.UNROLL, 0, 0),Opt(OptOps.LOCAL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
      # check the behavior of a full UNROLL on an existing GROUP
      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 2)], [("blue",16),("blue",32),("cyan",2),("green",16),("magenta",2)]),
      ([Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.GROUP, 0, 0),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
      ([Opt(OptOps.GROUP, 0, 0),Opt(OptOps.LOCAL, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",16),("blue",32),("cyan",2),("magenta",32)]),
      ([Opt(OptOps.GROUP, 0, 2),Opt(OptOps.UNROLL, 0, 0)], [("blue",32),("blue",32),("red",16),("magenta",2)]),
    ]
    helper_linearizer_opt(r, [x[0] for x in opts_shapes], color_sizes=[x[1] for x in opts_shapes])

if __name__ == '__main__':
  unittest.main()