# test_schedule.py
  1. # this will be the new test_ops for the next level
  2. # schedule confirms the right things are capable of fusing
  3. # NOTE: this has overlap with external_test_opt.py
  4. import unittest
  5. import numpy as np
  6. from typing import List, Optional, Union
  7. from tinygrad import nn, dtypes
  8. from tinygrad.device import Device
  9. from tinygrad.tensor import Tensor
  10. from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
  11. from tinygrad.helpers import DEBUG, flatten, getenv
  12. from tinygrad.codegen.kernel import Kernel
  13. from tinygrad.engine.graph import print_tree
  14. from tinygrad.engine.schedule import create_schedule
  15. from tinygrad.engine.realize import run_schedule
  16. from test.helpers import is_dtype_supported
  17. from tinygrad.function import Function
  18. from tinygrad.lazy import LazyBuffer, view_supported_devices
  19. class KernelCountException(Exception): pass
  20. def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
  21. if isinstance(t, Tensor): t = [t]
  22. seen = set()
  23. if to_prerealize:
  24. for pre in to_prerealize:
  25. for s in pre.schedule(seen=seen.copy()):
  26. for i,out in enumerate(s.outputs):
  27. seen.add(out)
  28. sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
  29. if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.KERNEL]
  30. if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
  31. if len(sched) != allowed or DEBUG >= 3:
  32. for i, s in enumerate(sched):
  33. print("kernel", i+1)
  34. print_tree(s.ast)
  35. if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
  36. # test the (sink) ops linearize
  37. for s in sched:
  38. if s.ast.op is not MetaOps.KERNEL: continue
  39. l = Kernel(s.ast)
  40. l.hand_coded_optimizations()
  41. l.linearize()
  42. return sched
  43. class TestSchedule(unittest.TestCase):
  44. def test_basic_binop_fusion(self):
  45. a = Tensor.empty(10)
  46. b = Tensor.empty(10)
  47. c = Tensor.empty(10)
  48. d = a+b+c
  49. check_schedule(d, 1)
  50. def test_basic_binop_fusion_deep(self):
  51. a = Tensor.empty(10)
  52. b = Tensor.empty(10)
  53. c = Tensor.empty(10)
  54. d = Tensor.empty(10)
  55. e = a+b+c+d
  56. check_schedule(e, 1)
  57. def test_mulacc_fusion(self):
  58. a = Tensor.empty(10)
  59. b = Tensor.empty(10)
  60. c = (a*b).sum()
  61. check_schedule(c, 1)
  62. def test_mulacc_relu_fusion(self):
  63. a = Tensor.empty(10)
  64. b = Tensor.empty(10)
  65. c = (a*b).sum().relu()
  66. check_schedule(c, 1)
  67. def test_binop_reshape_fusion(self):
  68. a = Tensor.empty(10)
  69. b = Tensor.empty(10)
  70. c = Tensor.empty(5,2)
  71. d = (a+b).reshape(5,2)+c
  72. check_schedule(d, 1)
  73. def test_binop_permute_fusion(self):
  74. a = Tensor.empty(2,5)
  75. b = Tensor.empty(2,5)
  76. c = Tensor.empty(5,2)
  77. d = (a+b).permute(1,0)+c
  78. check_schedule(d, 1)
  79. def test_constants_are_embedded(self):
  80. a = Tensor.empty(3,3) * 2
  81. check_schedule(a, 2, filter_sink=False)
  82. def test_binop_elu_fusion(self):
  83. a = Tensor.empty(10)
  84. b = a.elu()
  85. check_schedule(b, 1)
  86. def test_binop_reshape_reduce_fusion(self):
  87. a = Tensor.empty(100)
  88. b = Tensor.empty(100)
  89. c = (a+b).reshape(10, 10).sum(axis=0, keepdim=True)
  90. check_schedule(c, 1)
  91. def test_reduce_reshape_binop_fusion(self):
  92. a = Tensor.empty(10,10)
  93. b = Tensor.empty(10)
  94. c = a.sum(axis=0) + b
  95. check_schedule(c, 1)
  96. # not pushing permutes through reduces
  97. def test_reduce_permute_binop_fusion(self):
  98. a = Tensor.empty(10,10,10)
  99. b = Tensor.empty(10,10,1)
  100. c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
  101. with self.assertRaises(KernelCountException): check_schedule(c, 1)
  102. def test_binop_early_reshape_reduce_fusion(self):
  103. a = Tensor.empty(100)
  104. b = Tensor.empty(100)
  105. c = Tensor.empty(10,10)
  106. d = ((a+b).reshape(10,10) + c).sum(axis=0)
  107. check_schedule(d, 1)
  108. def test_diamond_folded(self):
  109. a = Tensor.empty(10)
  110. b = Tensor.empty(10)
  111. c = Tensor.empty(10)
  112. d = Tensor.empty(10)
  113. ab = a+b
  114. e = (ab+c) + (ab+d)
  115. check_schedule(e, 1)
  def test_cache_binaryop(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a+b
    check_schedule(d, 0, [c])

  # failing in new lazy
  def test_cache_binaryop_reshaped(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a.reshape(10,1)+b.reshape(10,1)
    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])

  # failing in new lazy
  def test_cache_binaryop_transpose(self):
    a = Tensor.empty(10,10)
    b = Tensor.empty(10,10)
    c = (a.T*b.T).T #.contiguous()
    d = a*b
    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])

  def test_cache_two_reduceops(self):
    a = Tensor.empty(10)
    b = a.sum()
    c = a.sum()
    bc = b+c
    check_schedule(bc, 1)

  def test_cache_reduce_parent(self):
    x = Tensor.empty(32)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out = r0 + r1
    schedule = check_schedule(out, 2)
    reduceops = [x for si in schedule for x in si.ast.lazyops if x.op in ReduceOps]
    assert len(reduceops) == 2

  def test_cache_reduce_multiple_children(self):
    x = Tensor.empty(32)
    y = Tensor.empty(4, 4)
    r0 = x.mean(axis=0, keepdim=True)
    r1 = (x - r0).sum(axis=0).div(2)
    out0 = r0 + y
    out1 = r1 + y
    schedule = check_schedule([out0, out1], 4)
    reduceops = [x for si in schedule for x in si.ast.lazyops if x.op in ReduceOps]
    assert len(reduceops) == 2

  def test_fold_double_unary(self):
    y = Tensor.empty(2)
    out = y.sum(keepdim=True).sqrt().__neg__()
    check_schedule(out, 1)
  #@unittest.skip("may want to reconsider this")
  def test_fold_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,32,4,4)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(img)
      check_schedule(out, 3)

  def test_fold_conv_batchnorm_notrain(self):
    with Tensor.train(False):
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm(self):
    with Tensor.train():
      img = Tensor.empty(1,3,8,8)
      c1 = nn.Conv2d(3,32,3)
      bn = nn.BatchNorm2d(32, track_running_stats=False)
      out = bn(c1(img)).relu()
      check_schedule(out, 4, [c1.weight, c1.bias])

  def test_fold_conv_batchnorm_optim(self):
    # this is too high
    for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
      with self.subTest(optim=optim.__name__):
        with Tensor.train():
          img = Tensor.ones(1,3,4,4)
          c1 = nn.Conv2d(3,32,3)
          bn = nn.BatchNorm2d(32, track_running_stats=False)
          opt = optim(nn.state.get_parameters([c1, bn]))
          img_bn = bn(c1(img)).elu().sum()
          opt.zero_grad()
          img_bn.backward()
          check_schedule(opt.schedule_step(), cnt)

  def test_fold_conv_relu_backward(self):
    c1 = nn.Conv2d(3,16,3, bias=False)
    c1.weight.requires_grad = True
    # run
    img = Tensor.rand(2,3,64,64, requires_grad=True)
    c1(img).relu().mean().backward()
    # TODO: this should be 4, not 5
    # img.grad is requiring two reduces
    check_schedule([img.grad, c1.weight.grad], 5)

  def test_fold_batchnorm_backward(self):
    with Tensor.train():
      x = Tensor.empty((2, 16, 8, 8)).contiguous()
      bn = nn.BatchNorm2d(16)
      bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
      fw = bn(x).contiguous_backward().relu().contiguous()
      fw.sum().backward()
      # TODO: this is too many
      check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 10)

  def test_fold_conv_relu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.ones(2,3,64,64)
    out = c1(img).relu()
    check_schedule(out, 1, [c1.weight, c1.bias])

  def test_fold_conv_relu_alt(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])

  def test_fold_conv_relu_nobias(self):
    img = Tensor.ones(1,4,8,8)
    c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False)
    out = img.sequential([c1, Tensor.relu, c2, Tensor.relu])
    check_schedule(out, 2, [c1.weight, c2.weight, img])

  def test_fold_conv_elu(self):
    c1 = nn.Conv2d(3,16,3)
    # run
    img = Tensor.rand(2,3,64,64)
    out = c1(img).elu()
    check_schedule(out, 1, [c1.weight, c1.bias, img])

  def test_fold_conv_elu_alt(self):
    img = Tensor.ones(1,4,8,8).contiguous()
    c1 = nn.Conv2d(4, 4, kernel_size=3)
    c2 = nn.Conv2d(4, 4, kernel_size=3)
    img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu])
    check_schedule(img_conv, 2, [*nn.state.get_parameters(c1), *nn.state.get_parameters(c2), img])

  def test_two_sum(self):
    img = Tensor.empty(64,64)
    x = (img.sum(0) + img.sum(1))
    out = x.relu()
    del x # is 3 without this
    check_schedule(out, 2)
  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape(self):
    a = Tensor.empty(16,16)
    b = Tensor.empty(16,16)
    c = (a+b).reshape(4,4,4,4).permute(2,3,0,1).contiguous()
    check_schedule(c, 1)

  #@unittest.skip("failing in old lazy")
  def test_push_permute_through_reshape_alt(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(4,4,4,4)
    c = (a+b).reshape(16,16).permute(1,0).contiguous()
    check_schedule(c, 1)

  def test_no_binop_rerun(self):
    a = Tensor.empty(16)
    b = Tensor.empty(16)
    c = a+b
    d = (a+b).reshape(16,1)
    check_schedule(d, 0, [c])

  def test_multi_permute_should_collapse(self):
    a = Tensor.empty(4,4,4,4)
    b = Tensor.empty(16)
    c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
    check_schedule(c, 1)

  def test_fancy_reshape_fusion(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = a+b
    d = a.reshape(10,1)+b.reshape(10,1)
    out = c.sum() + d.sum()
    with self.assertRaises(KernelCountException): check_schedule(out, 1)

  def test_children_dont_push(self):
    a = Tensor.empty(10, 10, 1)
    b = Tensor.empty(10, 10, 1)
    d = (a+b).expand(10, 10, 10)
    e = (a+b).permute(2,1,0)
    f = d+e
    check_schedule(f, 2)

  # failing in new lazy
  def test_dont_fuse_binops_with_children(self):
    a = Tensor.empty(10)
    b = Tensor.empty(10)
    c = Tensor.empty(10)
    keep_me = a+b
    e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
    d = keep_me+c
    with self.assertRaises(KernelCountException): check_schedule(d, 2)
    with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])

  #@unittest.skip("failing in old lazy")
  def test_permute_breaks_fusion(self):
    a = Tensor.empty(10, 10, 10)
    b = Tensor.empty(10, 10)
    c = (a.sum(axis=2) + b).permute(1,0)
    d = c.permute(1,0)
    check_schedule(d, 1)

  def test_some_permute_fusion(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(1, 16)
    d = (a.T + b.expand(8192, 16).T)
    c = a + b.expand(8192, 16)
    e = d.T
    check_schedule(c, 1)
    check_schedule(e, 1)

  def test_shrink_fuse(self):
    a = Tensor.empty(8192, 16)
    b = Tensor.empty(8192, 16)
    c = a * b
    d = Tensor.empty(1, 16)
    e = c[0] * d
    check_schedule(e, 1)

  def test_expand_nofuse(self):
    a = Tensor.empty(1, 16)
    b = Tensor.empty(1, 16)
    c = a * b
    d = Tensor.empty(8192, 16)
    e = c * d
    check_schedule(e, 2)
  # this is the failing case in openpilot...it's very simple like this
  def test_image_conv_fusion(self):
    w1 = Tensor.empty(16, 16, 1, 1)
    b1 = Tensor.empty(16)
    w2 = Tensor.empty(16, 16, 1, 1)
    b2 = Tensor.empty(16)
    w3 = Tensor.empty(16, 16, 1, 1)
    b3 = Tensor.empty(16)
    x = Tensor.empty(1, 16, 32, 32)
    x = base = x.image_conv2d(w1, b1)
    x = x.image_conv2d(w2, b2) + base
    x = x.image_conv2d(w3, b3)
    # NOOP, 3 convs, contiguous
    with self.assertRaises(KernelCountException): check_schedule(x, 5)

  def test_image_conv_fusion_minimal(self):
    b1 = Tensor.empty(16)
    b2 = Tensor.empty(16)
    def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    x = x + b2.reshape(16,1)
    x = x + base
    del base
    x = p(x)
    check_schedule(x, 4)

  def test_image_conv_fusion_more_minimal(self):
    b1 = Tensor.empty(16)
    def p(x): return x.permute(1,0).contiguous().reshape(32,16,1).expand(32,16,16).sum(axis=2).permute(1,0)
    x = Tensor.empty(16, 32)
    x = base = p(x) + b1.reshape(16,1)
    x = p(x)
    del base
    check_schedule(x, 3)

  def test_resnet_block(self):
    old_training = Tensor.training
    Tensor.training = False
    in_planes, planes = 64, 64
    conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    bn1 = nn.BatchNorm2d(planes)
    conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
    bn2 = nn.BatchNorm2d(planes)
    x = Tensor.empty(1, 64, 32, 32)
    out = bn1(conv1(x)).relu()
    out = bn2(conv2(out))
    out = (out + x).relu()
    check_schedule(out, 2, [conv1.weight, conv2.weight])
    Tensor.training = old_training
  def test_contiguous_while_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.contiguous()
    check_schedule(out, 1, filter_sink=False)

  def test_contiguous_while_not_contiguous(self):
    x = Tensor.empty(1, 64, 32, 32)
    out = x.permute(0,2,3,1).contiguous()
    check_schedule(out, 2, filter_sink=False)

  def test_fold_with_contiguous(self):
    a = Tensor.randn(16, 16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = (a.sum(2).contiguous() + b).contiguous()
    check_schedule(c, 2)

  def test_double_from(self):
    x = Tensor([1,2,3,4])
    out = x.to('npy')
    check_schedule(out, 0, filter_sink=False)

  def test_pow_const_tensor_simplified(self):
    x = Tensor([1,2,3,4])
    # NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
    out = x ** Tensor(2)
    check_schedule(out, 1)

  def test_pow_const_tensor_to_zero(self):
    x = Tensor([1,2,3,4])
    out = x ** Tensor(0)
    # NOTE: this is ConstBuffer 0 + ConstBuffer 1
    check_schedule(out, 0)

  def test_zero_size(self):
    x = Tensor.empty(2, 3, 0)
    out = x + 1
    check_schedule(out, 0, filter_sink=False)

  def test_reduce_permute_nofuse(self):
    x = Tensor.empty(32, 32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(axis=2).T+y
    check_schedule(out, 2)

  def test_two_elus_sum(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(32, 32)
    out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
    check_schedule(out, 2)

  # multireduce spec
  @unittest.skipUnless(getenv("SPLIT_REDUCEOP", 1), "Testing split reducop requires SPLIT_REDUCEOP")
  def test_preserve_multistage_reduce(self):
    big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
    x = Tensor.randn(big_enough).realize()
    out = (x - x.max(keepdim=True)).max()
    run_schedule(check_schedule(out, 4))
    np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(keepdims=True)).max())

  def test_multistage_reduce(self):
    x = Tensor.empty(32, 32, 32)
    out = x.sum(2).relu().sum(1)
    check_schedule(out, 2)

  def test_multistage_reduce_fork(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out2 = x + 1
    out = x.relu().sum(1) + out2[0]
    check_schedule(out, 2)

  # multireduce spec
  def test_example_matmul(self):
    x = Tensor.eye(64, requires_grad=True)
    y = Tensor.eye(64, requires_grad=True)
    z = y.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.ones((64,64)))

  def test_contiguous_add(self):
    x = Tensor.empty(32)
    y = Tensor.empty(32)
    z = Tensor.empty(32)
    out = (x+y).contiguous()+z
    check_schedule(out, 2)

  def test_double_sum_ref(self):
    x = Tensor.empty(32, 32, 32)
    x = x.sum(2)
    out = x + x[:, 4]
    check_schedule(out, 2)

  def test_reduce_shrink(self):
    x = Tensor.empty(32, 32)
    y = Tensor.empty(16)
    x = x.sum(1)
    x = x[:16]
    out = x + y
    check_schedule(out, 2) # TODO: this should be 1
  # multireduce spec
  def test_multireduce_shrink(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(32, 32).realize()
    b = Tensor.randn(32, 32).realize()
    c = Tensor.randn(16).realize()
    a_out = a.sum(1)
    a_out = a_out[:16]
    b_out = b.sum(1)
    b_out = b_out[:16]
    out = a_out + b_out + c
    # run_schedule(check_schedule(out, 2)) # TODO: this should be 1 (can we make it 1 with the new linearizer?)
    run_schedule(check_schedule(out, 3))
    np.testing.assert_allclose(out.numpy(), a.numpy().sum(axis=1)[:16] + b.numpy().sum(axis=1)[:16] + c.numpy(), atol=1e-4, rtol=1e-4)

  # broken due to const folding and two contiguous are different kernels
  def test_const_no_recompute(self):
    x = Tensor(2) + Tensor(2)
    y = Tensor(2) + Tensor(2)
    out = x.contiguous() + y.contiguous()
    with self.assertRaises(KernelCountException): check_schedule(out, 2, filter_sink=False)

  # multireduce spec
  def test_reduce_same_size(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    out0 = a.sum() + 2
    out1 = a.sum() + 4
    out2 = out0 * out1
    run_schedule(check_schedule([out0, out1, out2], 1))
    np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out1.numpy(), out1_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out2.numpy(), out0_np*out1_np, atol=1e-4, rtol=1e-6)

  # multireduce spec
  def test_reduce_multiple_paths(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    out0 = a.sum().exp2()
    # out1 has two paths to a.sum()
    out1 = a.sum() + out0
    run_schedule(check_schedule([out0, out1], 1))
    np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+out0_np, atol=1e-4, rtol=1e-6)

  # multireduce spec
  def test_multireduce_reduce_multiple_paths(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    out0 = a.sum().exp2()
    out1 = a.sum() + out0
    b = (a + out0 + out1)
    out2 = b.sum().exp2()
    out3 = b.sum() + out2
    # run_schedule(check_schedule([out0, out1, out2, out3], 1))
    run_schedule(check_schedule([out0, out1, out2, out3], 2))
    np.testing.assert_allclose(out0.numpy(), np_out0:=np.exp2(a.numpy().sum()), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), np_out1:=a.numpy().sum()+np_out0, atol=1e-4, rtol=1e-4)
    np_b = (a.numpy() + np_out0 + np_out1)
    np.testing.assert_allclose(out2.numpy(), np_out2:=np.exp2(np_b.sum()), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out3.numpy(), np_b.sum()+np_out2, atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_ext_reduce_child(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    b = Tensor.randn(4, 4).realize()
    # b.sum() is not a descendant of the fused nodes
    out0 = a.sum() + b.sum() + 2
    out1 = a.sum() + b.sum() + 4
    # run_schedule(check_schedule([out0, out1], 1))
    run_schedule(check_schedule([out0, out1], 4))
    np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+b.numpy().sum()+2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy().sum()+4, atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_multiple_paths_midreduce(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    r = a.sum()
    out0 = r.exp2()
    # reduce node in the indirect path from r to out2
    out1 = (a - out0).max()
    out2 = r + out1
    # run_schedule(check_schedule([r, out0, out1, out2], 1))
    run_schedule(check_schedule([r, out0, out1, out2], 4))
    np.testing.assert_allclose(r.numpy(), r_np:=a.numpy().sum(), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(r_np), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), out1_np:=(a.numpy() - out0_np).max(), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out2.numpy(), r_np + out1_np, atol=1e-4, rtol=1e-4)

  # multireduce spec
  def test_reduce_multiple_paths_midreduce_fused(self):
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4).realize()
    b = Tensor.randn(4, 4).realize()
    out0 = a.sum() + 4
    out1 = b.max() + out0*2
    out2 = a.sum() + out1
    # run_schedule(check_schedule([out0, out1, out2], 1))
    run_schedule(check_schedule([out0, out1, out2], 4))
    np.testing.assert_allclose(out0.numpy(), out0_np:=a.numpy().sum()+4, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out1.numpy(), out1_np:=b.numpy().max() + out0_np*2, atol=1e-4, rtol=1e-6)
    np.testing.assert_allclose(out2.numpy(), a.numpy().sum() + out1_np, atol=1e-4, rtol=1e-6)
  559. # multireduce spec
  560. def test_reduce_multiple_paths_midexpand(self):
  561. Tensor.manual_seed(0)
  562. a = Tensor.randn(4, 4).realize()
  563. b = Tensor.randn(4, 4, 4).realize()
  564. r = a.sum()
  565. out0 = r.exp2()
  566. # e1 is in the indirect path from a.sum() to out1
  567. e = b + out0
  568. out1 = r + e[0][0][0]
  569. # run_schedule(check_schedule([r, out0, out1, e], 3)) # 1 or 2 or 3? should be 1 (one reduce) but the different outputs might make it 3
  570. run_schedule(check_schedule([r, out0, out1, e], 4))
  571. np.testing.assert_allclose(r.numpy(), r_np:=a.numpy().sum(), atol=1e-4, rtol=1e-4)
  572. np.testing.assert_allclose(out0.numpy(), out0_np:=np.exp2(r_np), atol=1e-4, rtol=1e-4)
  573. np.testing.assert_allclose(e.numpy(), e_np:=b.numpy() + out0_np, atol=1e-4, rtol=1e-4)
  574. np.testing.assert_allclose(out1.numpy(), r_np + e_np[0][0][0], atol=1e-4, rtol=1e-4)
  575. # changed by multireduce
  576. def test_reduce_expand_child(self):
  577. Tensor.manual_seed(0)
  578. a = Tensor.randn((32, 32, 32)).realize()
  579. b = Tensor.randn((1, 16)).realize()
  580. out0 = a.sum() + 2
  581. out1 = a.sum() + b
  582. # run_schedule(check_schedule([out0, out1], 2))
  583. run_schedule(check_schedule([out0, out1], 4))
  584. np.testing.assert_allclose(out0.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
  585. np.testing.assert_allclose(out1.numpy(), a.numpy().sum()+b.numpy(), atol=1e-4, rtol=1e-4)
  586. def test_reduce_shrink_child(self):
  587. a = Tensor.empty(100, 100)
  588. b = Tensor.empty(10,)
  589. c = a.sum() + b[0]
  590. d = a.sum() + 2
  591. check_schedule([c, d], 1)
  592. def test_reduce_multiple_paths_midshrink(self):
  593. a = Tensor.empty(4, 4)
  594. r = a.sum(axis=1)
  595. out0 = r.exp2()
  596. out1 = out0[0] + out0
  597. check_schedule([r, out0, out1], 3)
  598. def test_reduce_shrink_output(self):
  599. a = Tensor.empty(4, 4)
  600. r = a.sum(keepdim=True)
  601. out0 = r.exp2()
  602. out1 = out0[0] + Tensor.empty(1, )
  603. check_schedule([r, out0, out1], 3)
  604. # multireduce spec
  605. def test_std_multireduce_fusion(self):
  606. Tensor.manual_seed(0)
  607. x = Tensor.randn(4, 32).realize()
  608. out = x.std(-1)
  609. run_schedule(check_schedule(out, 2))
  610. np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
  611. # multireduce spec
  612. def test_argmin_multireduce_fusion(self):
  613. Tensor.manual_seed(0)
  614. x = Tensor.randn(4, 32).realize()
  615. out = x.argmin(-1)
  616. run_schedule(check_schedule(out, 3))
  617. np.testing.assert_equal(out.numpy(), x.numpy().argmin(axis=-1))
  618. # multireduce spec
  619. def test_argmax_multireduce_fusion(self):
  620. Tensor.manual_seed(0)
  621. x = Tensor.randn(4, 32).realize()
  622. out = x.argmax(-1)
  623. run_schedule(check_schedule(out, 3))
  624. np.testing.assert_equal(out.numpy(), x.numpy().argmax(axis=-1))
  625. # multireduce spec
  626. def test_scaled_dot_product_attention_multireduce_fusion(self):
  627. Tensor.manual_seed(0)
  628. q = Tensor.randn(32,8,16,64).realize()
  629. k = Tensor.randn(32,8,16,64).realize()
  630. v = Tensor.randn(32,8,16,64).realize()
  631. out = Tensor.scaled_dot_product_attention(q,k,v)
  632. check_schedule(out, 5) # correctness checked in test_ops
  633. # multireduce spec
  634. def test_ugly_reduceop_pairing(self):
  635. Tensor.manual_seed(0)
  636. a = Tensor.randn(4, 32).realize()
  637. b = Tensor.randn(4, 32).realize()
  638. c = Tensor.randn(4, 32).realize()
  639. out = (c * a.sum(-1, keepdim=True)).sum(-1) + (b * a.sum(-1, keepdim=True)).sum(-1) # a.sum has >1 children but should still fuse
  640. # run_schedule(check_schedule(out, 1))
  641. run_schedule(check_schedule(out, 3))
  642. np.testing.assert_allclose(out.numpy(), \
  643. (c.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1) + (b.numpy()*a.numpy().sum(axis=-1,keepdims=True)).sum(-1), atol=1e-4, rtol=1e-4)
  644. # multireduce spec
  645. def test_reduce_expand_reduce_fusion(self):
  646. Tensor.manual_seed(0)
  647. a = Tensor.randn(4, 32).realize()
  648. out = (a+a.sum(-1, keepdim=True)).sum(-1)
  649. # run_schedule(check_schedule(out, 1))
  650. run_schedule(check_schedule(out, 2))
  651. np.testing.assert_allclose(out.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
  652. # multireduce spec
  653. def test_reduce_expand_reduce_expand_fusion(self):
  654. Tensor.manual_seed(0)
  655. a = Tensor.randn(4, 32).realize()
  656. out = a+(a+a.sum(-1,keepdim=True)).sum(-1, keepdim=True)
  657. # run_schedule(check_schedule(out, 2))
  658. run_schedule(check_schedule(out, 3))
  659. np.testing.assert_allclose(out.numpy(), \
  660. a.numpy()+(a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
  661. # multireduce spec
  def test_branching_reduces_and_expands_fusion(self):
    # out0 is both an output and the input of a second reduce; both outputs are scheduled
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 32).realize()
    out0 = a+a.sum(-1, keepdim=True)
    out1 = out0.sum(-1)
    # run_schedule(check_schedule(out, 2))  # target count once full multireduce fusion lands
    run_schedule(check_schedule([out0, out1], 3))
    np.testing.assert_allclose(out0.numpy(), a.numpy()+a.numpy().sum(axis=-1,keepdims=True), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out1.numpy(), (a.numpy()+a.numpy().sum(axis=-1,keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
  671. # multireduce spec
  def test_multireduce_fusion_simple_sequential(self):
    # second reduce depends on the first (sequential chain); currently 2 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = (y + x.sum(axis=-1, keepdim=True)).sum(axis=-1)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (y.numpy() + x.numpy().sum(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
  680. # multireduce spec
  def test_multireduce_fusion_simple_parallel(self):
    # two independent reduces combined elementwise (parallel); currently 2 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = y.sum(axis=-1) + x.sum(axis=-1)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), y.numpy().sum(axis=-1) + x.numpy().sum(axis=-1), atol=1e-4, rtol=1e-4)
  689. # multireduce spec
  def test_multireduce_fusion_sequential(self):
    # std is internally mean followed by a second reduce; currently 2 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = x.std(-1)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
  697. # multireduce spec
  def test_multireduce_fusion_parallel(self):
    # two independent std chains added together; each std is 2 kernels, 4 total
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = x.std(-1) + y.std(-1)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 4))
    np.testing.assert_allclose(out.numpy(), x.numpy().std(axis=-1, ddof=1) + y.numpy().std(axis=-1, ddof=1), atol=1e-4, rtol=1e-4)
  706. # multireduce spec
  def test_multireduce_diffops_sequential(self):
    # a max feeding a sum (different reduce ops, sequential); currently 2 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    out = (x - x.max(-1, keepdim=True)).sum(-1)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1), atol=1e-4, rtol=1e-4)
  714. # multireduce spec
  def test_multireduce_fusion_diffops_parallel(self):
    # an independent sum and max combined elementwise; currently 2 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    out = x.sum(-1) + y.max(-1)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), x.numpy().sum(axis=-1) + y.numpy().max(axis=-1), atol=1e-4, rtol=1e-4)
  723. # multireduce spec
  def test_multireduce_fusion_sequential_and_parallel(self):
    # mixes sequential (max -> mean) and parallel (x and y branches) reduces; currently 6 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 32).realize()
    y = Tensor.randn(4, 32).realize()
    mu = (x - x.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True) + (y - y.max(axis=-1, keepdim=True)).mean(axis=-1, keepdim=True)
    out = [((x - mu).square().sum(-1)/x.shape[-1]).sqrt(), ((y - mu).square().sum(-1)/y.shape[-1]).sqrt()]
    np_mu = (x.numpy() - x.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True) + \
      (y.numpy() - y.numpy().max(axis=-1, keepdims=True)).mean(axis=-1, keepdims=True)
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 6))
    np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
  736. # multireduce spec
  def test_multimatmul_fusion(self):
    # two independent matmuls summed elementwise; currently 2 kernels
    Tensor.manual_seed(0)
    a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
    c,d = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
    out = a@b + c@d
    # run_schedule(check_schedule(out, 1))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), a.numpy()@b.numpy() + c.numpy()@d.numpy(), atol=1e-4, rtol=1e-4)
  def test_softmax_fusion(self):
    # softmax = max, exp/sub, sum, div; currently schedules as 3 kernels
    Tensor.manual_seed(0)
    x = Tensor.randn(4, 12, 64, 64).realize()
    out = x.softmax()
    # run_schedule(check_schedule(out, 2))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 3))
    expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
    np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
  753. # changed by: multireduce spec
  def test_layernorm_onelayer_fusion(self):
    # LayerNorm forward (mean, variance, normalize+affine); currently 3 kernels
    Tensor.manual_seed(0)
    layer = nn.LayerNorm([10, 10])
    layer.weight = Tensor.randn(10,10).realize()
    layer.bias = Tensor.randn(10,10).realize()
    x = Tensor.randn(20, 5, 10, 10).realize()
    out = layer(x)
    # run_schedule(check_schedule(out, 2))  # target count once full multireduce fusion lands
    run_schedule(check_schedule(out, 3))
    y = (x.numpy() - x.numpy().mean(layer.axis, keepdims=True))
    expected = y / np.sqrt((y*y).mean(layer.axis, keepdims=True) + layer.eps)
    np.testing.assert_allclose(out.numpy(), expected * layer.weight.numpy() + layer.bias.numpy(), atol=1e-4, rtol=1e-4)
  def test_scaled_dot_product_attention_fusion(self):
    # masked (non-causal) attention over empty buffers schedules into 5 kernels
    q, k, v, mask = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
    attn = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    check_schedule(attn, 5)
  def test_scaled_dot_product_attention_causal_fusion(self):
    # with is_causal=True the causal mask construction adds one kernel over the masked case
    q, k, v, mask = (Tensor.empty(32, 8, 16, 16) for _ in range(4))
    attn = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True)
    check_schedule(attn, 6)
  def test_adam_step_fusion(self):
    # a full Adam optimizer step on one Linear layer schedules into 11 kernels
    with Tensor.train():
      x = Tensor.empty(4, 64, 768)
      layer = nn.Linear(768, 768*4)
      opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
      layer(x).relu().sum().backward()
      check_schedule(opt.schedule_step(), 11)
  def test_adam_conv_fuse(self):
    # Adam step for a single Conv2d (with bias) schedules into 11 kernels
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 11)
  def test_adam_2convs_fuse(self):
    # Adam step across two stacked bias-free convs schedules into 13 kernels
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)
  def test_sgd_conv_fuse(self):
    # plain SGD step for a single Conv2d (with bias) schedules into 7 kernels
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,32,3)
      opt = nn.optim.SGD(nn.state.get_parameters(c1))
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)
  def test_sgd_2convs_fuse(self):
    # plain SGD step across two stacked bias-free convs schedules into 7 kernels
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 7)
  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    # SGD with nesterov momentum and weight decay does extra buffer updates: 9 kernels
    with Tensor.train():
      img = Tensor.empty(2,3,4,4)
      c1 = nn.Conv2d(3,16,3,bias=False)
      c2 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 9)
  def test_sgd_4convs_fuse(self):
    # SGD step across a four-conv stack schedules into 22 kernels
    with Tensor.train():
      img = Tensor.empty(2,3,64,64)
      c1 = nn.Conv2d(3,4,3,bias=False)
      c2 = nn.Conv2d(4,8,3,bias=False)
      c3 = nn.Conv2d(8,16,3,bias=False)
      c4 = nn.Conv2d(16,32,3,bias=False)
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 22)
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_prefer_half_buffer(self):
    # intermediate buffers produced by a half()->float() round trip should be stored as half
    x = Tensor.ones(4).contiguous().realize()
    # y = Tensor.ones(4).contiguous().realize()
    z = Tensor.ones(4, 4).contiguous().realize()
    # should not create extra kernel if output will be realized anyways
    dummy = x.sum().half().float()
    check_schedule(dummy, 1)
    dummy = x.sum().half().float().contiguous() + 1
    check_schedule(dummy, 2)
    # shared between two outputs
    shared = x.sum().half().float()
    a = shared * 2
    b = shared * 3
    sched = check_schedule([a, b], 1)
    # every kernel except the last two should write half outputs
    for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)
    # reduce
    a = z.sum(axis=0).half().float().sum(axis=0)
    sched = check_schedule(a, 2)
    # every kernel except the final one should write half outputs
    for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs)
    # expand
    # expand will realize just after the .float(), so requires change to realize-before-expand
    # normal = (x.sum().half().float().reshape(1) * y).sum()
    # sched = check_schedule(normal, 2)
    # for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs[:-1])
    # parallel reduce
    # a = x.sum().half().float() * y.sum().half().float()
    # b = a + 1
    # c = a + 2
    # sched = check_schedule([b, c], 4)
    # doesn't store either in half because it doesn't chase
  def test_reduce_simple_chase(self):
    # r has two reduce consumers; the +6 should stay fused with r's kernel (ADD at the store root)
    a = Tensor.empty(4, 4, 4)
    r = a.sum(0) + 6
    b = r.sum(0) * 4
    c = r.sum(1) * 2
    schedule = check_schedule([b, c], 3)
    assert schedule[0].ast.src[0].src[0].op is BinaryOps.ADD
  873. # multireduce spec
  def test_multireduce_simple_chase(self):
    # multireduce variant of the simple chase; currently 4 kernels instead of the target 3
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4, 4).realize()
    r = (a + (a.sum(0, keepdim=True) + 6)).sum(0) * 2
    b = r.sum(0) + 8
    c = r.sum(1) + 12
    np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2
    # schedule = check_schedule([b,c], 3)  # target once multireduce fusion lands
    # assert schedule[0].ast[0].src[0].op is BinaryOps.MUL
    schedule = check_schedule([b,c], 4)
    run_schedule(schedule)
    np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(c.numpy(), np_r.sum(1) + 12, atol=1e-4, rtol=1e-4)
  def test_push_permute_chase(self):
    # a permute (r.T) of a reduce output should not break the fusion of r's +b (ADD stays at root)
    a = Tensor.empty(4, 4, 4)
    b = Tensor.empty(4, 4)
    r = a.sum(2) + b
    d = r.T * 4
    e = r * d
    schedule = check_schedule([d, e], 3)
    assert schedule[0].ast.src[0].src[0].op is BinaryOps.ADD
  895. # multireduce spec
  def test_multireduce_push_permute_chase(self):
    # multireduce variant: the second reduce over (d + a) must not fuse into r's kernel
    Tensor.manual_seed(0)
    a = Tensor.randn(4, 4, 4).realize()
    b = Tensor.randn(4, 4).realize()
    r = a.sum(2) + b
    d = r.T * 4
    e = r * (d + a).sum(2)
    schedule = check_schedule([d, e], 3) # make sure it doesn't fuse
    assert schedule[0].ast.src[0].src[0].op is BinaryOps.ADD
    run_schedule(schedule)
    np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4)
  def test_push_shrink_chase(self):
    # a shrink (r[:4]) of a reduce output should keep the +c fused into r's kernel
    a = Tensor.empty(16, 16)
    b = Tensor.empty(4)
    c = Tensor.empty(16, )
    r = a.sum(1) + c
    d = r[:4] * b
    schedule = check_schedule(d, 2)
    assert schedule[0].ast.src[0].src[0].op is BinaryOps.ADD
  916. # multireduce spec
  def test_multireduce_push_shrink_chase(self):
    # multireduce variant of the shrink chase; currently 3 kernels instead of the target 2
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(4).realize()
    c = Tensor.randn(16, ).realize()
    d = Tensor.randn(16, 16).realize()
    r = a.sum(1) + c
    out = r[:4] * b + d.sum(1)[:4]
    # schedule = check_schedule(out, 2)  # target once multireduce fusion lands
    schedule = check_schedule(out, 3)
    assert schedule[0].ast.src[0].src[0].op is BinaryOps.ADD
    run_schedule(schedule)
    np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4)
  def test_midreduce_nochase(self):
    # two different reduces of the same input feed one add; the first kernel roots at the MAX reduce
    a = Tensor.empty(16, 16)
    b = (a.sum(0) + a.max(1)) + 2
    schedule = check_schedule(b, 2)
    assert schedule[0].ast.src[0].src[0].op is ReduceOps.MAX
  935. # multireduce spec
  def test_multireduce_midreduce_nochase(self):
    # four reduces of the same input combined; currently 4 kernels instead of the target 2
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
    # schedule = check_schedule(b, 2)  # target once multireduce fusion lands
    schedule = check_schedule(b, 4)
    assert schedule[0].ast.src[0].src[0].op is ReduceOps.MAX
    run_schedule(schedule)
    np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
  945. # changed by: multireduce spec
  946. # pattern in test_transformer
  def test_partial_fuse1(self):
    # a.sum() is shared between two outputs with another reduce in the mix; currently 3 kernels
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = (a.sum() - b.sum()) * 4
    # run_schedule(check_schedule([c, d], 1))  # target once multireduce fusion lands
    run_schedule(check_schedule([c, d], 3))
    np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), (a.numpy().sum() - b.numpy().sum()) * 4, atol=1e-4, rtol=1e-4)
  957. # changed by: multireduce spec
  958. # pattern in conv
  def test_partial_fuse2(self):
    # output c is reused as an input of d (pattern seen in conv); currently 2 kernels
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = b.sum() - c
    # run_schedule(check_schedule([c, d], 1))  # target once multireduce fusion lands
    run_schedule(check_schedule([c, d], 2))
    np.testing.assert_allclose(c.numpy(), a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), b.numpy().sum()-(a.numpy().sum()+2), atol=1e-4, rtol=1e-4)
  969. # changed by: multireduce spec
  970. # pattern in adam
  def test_partial_fuse3(self):
    # diamond of outputs built from two uses of a.sum() (pattern seen in adam); currently 2 kernels
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = a.sum() * 2
    e = c * d
    f = b.sum() - e
    # run_schedule(check_schedule([c, d, e, f], 1))  # target once multireduce fusion lands
    run_schedule(check_schedule([c, d, e, f], 2))
    np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(f.numpy(), b.numpy().sum() - e_np, atol=1e-4, rtol=1e-4)
  985. # changed by: multireduce spec
  def test_partial_fuse4(self):
    # like test_partial_fuse3 but f reduces over an expand of d; currently 3 kernels
    Tensor.manual_seed(0)
    a = Tensor.randn(16, 16).realize()
    b = Tensor.randn(16, 16).realize()
    c = a.sum() + 2
    d = a.sum() * 2
    e = c * d
    f = (b - d).sum() - e
    # run_schedule(check_schedule([c, d, e, f], 1))  # target once multireduce fusion lands
    run_schedule(check_schedule([c, d, e, f], 3))
    np.testing.assert_allclose(c.numpy(), c_np:=a.numpy().sum()+2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(d.numpy(), d_np:=a.numpy().sum()*2, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(e.numpy(), e_np:=c_np*d_np, atol=1e-4, rtol=1e-4)
    np.testing.assert_allclose(f.numpy(), (b.numpy()-d_np).sum()-e_np, atol=1e-4, rtol=1e-4)
  def test_pad_reduce_safe(self):
    # add is safe to compute inside the padded region, so pad+sum fuses into one kernel
    Tensor.manual_seed(0)
    a = Tensor.rand(3, 4, 5).realize()
    b = Tensor.rand(3, 4, 5).realize()
    out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
    run_schedule(check_schedule(out, 1))
    np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
  1007. # multireduce spec
  def test_multireduce_pad_reduce_safe(self):
    # two padded sums combined; currently 2 kernels instead of the target 1
    Tensor.manual_seed(0)
    a = Tensor.randn(3, 4, 5).realize()
    b = Tensor.randn(3, 4, 5).realize()
    out = (a.pad(((0, 1), (0, 1), (0, 1)), 1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), 1.0).sum()).contiguous()
    # run_schedule(check_schedule(out, 1))  # target once multireduce fusion lands
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.pad(a.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(keepdims=True) + \
      np.pad(b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-4)
  def test_pad_reduce_unsafe(self):
    # log2 is unsafe to evaluate in the padded region, so the pad splits it into 2 kernels
    Tensor.manual_seed(0)
    a = Tensor.rand(3, 4, 5).realize()
    out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), rtol=1e-6)
  1023. # multireduce spec
  def test_multireduce_pad_reduce_unsafe(self):
    # chained unsafe pads around two reduces; currently 4 kernels instead of the target 1
    Tensor.manual_seed(0)
    a = Tensor.randn(3, 4, 5).abs().realize()
    b = Tensor.randn(3, 4, 5).abs().realize()
    out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
    # run_schedule(check_schedule(out, 1))  # target once multireduce fusion lands
    run_schedule(check_schedule(out, 4))
    np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
      b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-6)
  def test_shrink_pad_safe(self):
    # an add is safe to evaluate through shrink+pad movement ops: one kernel total
    lhs = Tensor.ones((3, )).contiguous().realize()
    rhs = Tensor.ones((3, )).contiguous().realize()
    out = (lhs + rhs).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
    run_schedule(check_schedule(out, 1))
    np.testing.assert_equal(out.numpy(), [2, 0])
  def test_shrink_pad_unsafe(self):
    # exp2 cannot be evaluated in the padded region, so the pad forces a second kernel
    src = Tensor.ones((3, )).contiguous().realize()
    out = src.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_equal(out.numpy(), [2, 0])
  def test_base_change_shrink_pad(self):
    # shrink then pad over an unsafe exp2 base; schedules into 2 kernels
    a = Tensor.ones(3, 3).contiguous().realize()
    b = a.exp2()
    c = b[:-1, :-1]
    d = c.pad(((0, 1), (0, 1))) * 2
    run_schedule(check_schedule(d, 2))
    np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:-1, :-1], ((0, 1), (0, 1)))*2)
  def test_base_change_expand_pad(self):
    # insert a new axis then pad over an unsafe exp2 base; schedules into 2 kernels
    a = Tensor.ones(3, 3).contiguous().realize()
    b = a.exp2()
    c = b[:, None, :]
    d = c.pad(((0, 0), (1, 1), (0, 0))) * 2
    run_schedule(check_schedule(d, 2))
    np.testing.assert_equal(d.numpy(), np.pad(np.exp2(a.numpy())[:, None, :], ((0, 0), (1, 1), (0, 0)))*2)
  1058. # TODO like openpilot with imagef
  # TODO like openpilot with imagef
  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_base_change_expand_expand(self):
    # two casts with an expand between them; schedules into 2 kernels
    a = Tensor.ones(4, 4).contiguous().realize()
    b = a.cast(dtypes.half).expand(2, 4, 4)
    c = b.cast(dtypes.int).expand(2, 2, 4, 4)
    run_schedule(check_schedule(c, 2))
    np.testing.assert_equal(c.numpy(), np.ones(((2, 2, 4, 4)), dtype=np.int32))
  def test_base_change_pad_expand(self):
    # pad an elementwise add, change the base dtype via cast, then expand; 2 kernels
    a = Tensor.full((4, 4), 1.).contiguous().realize()
    b = Tensor.full((4, 4), 2.).contiguous().realize()
    c = (a + b).pad(((1, 1), (1, 1)))
    d = c.cast(dtypes.int).expand((2, 6, 6)) * 4
    run_schedule(check_schedule(d, 2))
    c_np = np.pad((np.full((4, 4), 2., dtype=np.float32) + np.full((4, 4), 1., dtype=np.float32)), ((1, 1), (1, 1)), constant_values=0.0)
    # reference cast is int32 to mirror dtypes.int above (was np.half, which only matched because 3.0 is exact in half)
    np.testing.assert_equal(d.numpy(), np.broadcast_to(c_np.astype(np.int32), (2, *c_np.shape)) * 4)
  def test_pad_reduce_unsafe_multiview_st(self):
    # in-place divide, row slice, pad, repeat builds a multi-view shapetracker; 3 kernels
    P = Tensor.ones(3, 3).contiguous()
    sums = P.sum(axis=1, keepdim=True)
    P /= sums
    p = P[0]
    p = p.pad(((1, 0), ))
    p = p.repeat([2])
    run_schedule(check_schedule(p, 3))
    tiny_ret = p.numpy()
    # replay the same computation in pure numpy and compare
    P = np.ones((3, 3), dtype=np.float32)
    sums = P.sum(axis=1, keepdims=True)
    P /= sums
    p = P[0]
    p = np.pad(p, (1, 0), 'constant')
    p = np.tile(p, 2)
    np.testing.assert_allclose(tiny_ret, p)
  @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
  def test_bitcast_subbufer(self):
    # bitcasts realized as buffer views currently cost an extra kernel
    a = Tensor.empty(1, dtype=dtypes.float32).realize()
    b = CycleBitcast.apply(a)
    check_schedule(b, 2) # this should fuse when it makes sense
  def test_bitcast_disable_subbufer(self):
    # with buffer views disabled the bitcast cycle fuses into a single kernel
    a = Tensor.empty(1, dtype=dtypes.float32).realize()
    b = CycleBitcast.apply(a, allow_buffer_view=False)
    check_schedule(b, 1)
  def test_reduceop_reshape_dont_push(self):
    # argmax currently schedules 3 kernels; the reduce is not pushed through its reshape
    Tensor.manual_seed(0)
    x = Tensor.randn(10, 20).realize()
    out = x.argmax(1)
    run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
  1104. class CycleBitcast(Function):
  1105. def forward(self, x: LazyBuffer, allow_buffer_view=True):
  1106. a = x.e(UnaryOps.NEG).cast(dtypes.int32, True, allow_buffer_view)
  1107. b = x.cast(dtypes.int32, True, allow_buffer_view)
  1108. return a.e(BinaryOps.ADD, b)
# run the test suite with per-test names printed when executed as a script
if __name__ == '__main__':
  unittest.main(verbosity=2)