# test_multitensor.py
  1. import unittest, functools, random
  2. from typing import List
  3. from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
  4. from tinygrad.ops import MetaOps, ReduceOps, BufferOps, BinaryOps
  5. from tinygrad.helpers import CI, getenv, prod, Context
  6. from tinygrad.nn.state import get_parameters, get_state_dict
  7. from tinygrad.engine.schedule import create_schedule
  8. from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner
  9. from tinygrad.multi import all_reduce, MultiLazyBuffer
  10. import numpy as np
  11. from hypothesis import given, strategies as strat, settings
  12. from test.helpers import is_dtype_supported
# Hypothesis config: more examples, no per-example deadline; DERANDOMIZE_CI makes CI runs reproducible.
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
# Aliases for virtual instances 0-5 of the default backend; these emulate a
# multi-device setup on a single machine for the sharding tests below.
d0 = f"{Device.DEFAULT}:0"
d1 = f"{Device.DEFAULT}:1"
d2 = f"{Device.DEFAULT}:2"
d3 = f"{Device.DEFAULT}:3"
d4 = f"{Device.DEFAULT}:4"
d5 = f"{Device.DEFAULT}:5"
devices_2 = (d1, d2)
devices_3 = (d1, d2, d3)
devices_4 = (d1, d2, d3, d4)
N = 128  # square-matrix size used by the matmul shard tests
# shard_x is "data parallel"
# shard_w is "model parallel"
  27. def _test_allreduce(t:Tensor):
  28. aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
  29. ts = t.shard(devices_4, 0).realize()
  30. b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, ts.lazydata.lbs), 0))
  31. b.realize()
  32. return aa, b
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestMultiTensor(unittest.TestCase):
  """Tests for tensors replicated (to_/shard axis=None) or sharded (shard_ with
  an axis) across multiple virtual instances of the default device."""
  def test_to(self):
    # to_ replicates: every device's lazybuffer keeps the full (256,) shape
    X = Tensor.ones(256).contiguous().realize()
    X.to_(devices_2)
    for lb in X.lazydata.lbs:
      assert lb.shape == (256,)
    (X + X).realize()
  def test_shard(self):
    # sharding 256 elements on axis 0 across 2 devices -> (128,) per device
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    for lb in X.lazydata.lbs:
      assert lb.shape == (128,)
    (X + X).realize()
  def test_sharded_arange(self):
    sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
    sharded_arange.realize()
    np.testing.assert_equal(sharded_arange.numpy(), np.arange(1000))
  def test_shard_no_recompile(self):
    # the same elementwise kernel should be reused on both shard devices,
    # not relinearized per device
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    out = (X + X)
    sched = create_schedule(out.lazydata.lbs)
    names = []
    for si, ei in zip(sched[:], lower_schedule(sched)):
      if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
      ei.run()
    assert names[-2] == names[-1], "function was relinearized"
  def test_sharded_memory(self):
    # Buffer may be stuck in track_cross_buffer
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    mem_base = GlobalCounters.mem_used
    # replicating (axis=None) a realized buffer allocates a full copy per device
    X = Tensor.ones(256).contiguous().realize()
    assert GlobalCounters.mem_used-mem_base== X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X.shard_(devices_4).realize()
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base
    # sharding on an axis keeps total memory equal to the unsharded tensor
    X = Tensor.ones(256).contiguous().realize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    X.shard_(devices_4, axis=0).realize()
    for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
    assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
    # without contiguous(), ones() allocates no buffer memory at all
    X = Tensor.ones(256).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X.shard_(devices_4).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X = Tensor.ones(256).realize()
    assert GlobalCounters.mem_used-mem_base == 0
    X.shard_(devices_4, axis=0).realize()
    assert GlobalCounters.mem_used-mem_base == 0
  def test_shard_same_device(self):
    # one of the shard targets is the tensor's current device
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, X.device), 0)
    (X + X).realize()
  def test_shard_plus_one_sum(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), 0)
    (X + 1).sum().realize()
  def test_shard_plus_one_sum_d0(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d2), 0)
    (X + 1).sum().realize()
  def test_numpy(self):
    X = Tensor.ones(256)
    X.shard_((d1, d2), 0)
    np.testing.assert_allclose(X.numpy(), 1)
  def _test_simple_add_axis(self, shard_x, shard_w):
    # shard_x/shard_w: axis to shard each operand on, or None to replicate
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), shard_x)
    W.shard_((d1, d2), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)
  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)
  def test_four_add(self):
    # sharded-on-axis-1 plus replicated operand across four devices
    X = Tensor.ones(256, 256).contiguous().realize()
    W = Tensor.ones(256, 256).contiguous().realize()
    X.shard_(devices_4, 1)
    W.shard_(devices_4, None)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)
  def test_elementwise_dtype(self):
    # comparison on a shrunk sharded tensor must match the numpy reference
    Tensor.manual_seed(0)
    X = Tensor.randn(8, 8).realize()
    W = Tensor.randn(8, 8).realize()
    X.shard_(devices_4, 0)
    W.shard_(devices_4, 0)
    O = X.shrink(((0, 2), None)) * W.shrink(((0, 2), None)) < 2
    np.testing.assert_allclose(O.numpy(), X.numpy()[0:2]*W.numpy()[0:2] < 2)
  @given(strat.sampled_from((4, 5)), strat.sampled_from((devices_2, devices_3)), strat.sampled_from((ReduceOps.SUM, ReduceOps.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
  def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
    # property test: reduce over any combination of shard axis / reduce axis /
    # sign must agree with numpy
    X = Tensor.rand(N*N).reshape(N, N).mul(sign)
    n = X.numpy()
    X.shard_(devices, shard_axis)
    f = {ReduceOps.SUM: lambda x: x.sum(reduce_axis), ReduceOps.MAX: lambda x: x.max(reduce_axis)}[rop]
    fX = f(X)
    fn = f(n)
    np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
  def test_allreduce_naive(self):
    # RING=0 forces the naive (non-ring) all-reduce implementation
    with Context(RING=0):
      a,b = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
  def test_allreduce_ring(self):
    # RING=2 forces the ring all-reduce implementation
    with Context(RING=2):
      a,b = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
  def test_copy_jit(self):
    # a jitted function containing a cross-device copy must stay correct on replay
    @TinyJit
    def copy_tensor(x:Tensor): return (x.to(f"{x.device.split(':')[0]}:1") + 1)
    for _ in range(5):
      t = Tensor.rand(256).realize()
      x = copy_tensor(t)
      np.testing.assert_equal((t+1).numpy(), x.numpy())
  def test_allreduce_naive_jit(self):
    with Context(RING=0):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        a,b = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
  def test_allreduce_ring_jit(self):
    with Context(RING=2):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        a,b = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=5)
  @unittest.skip("slow")
  def test_fuzz_allreduce(self):
    # naive (RING=0) and ring (RING=2) all-reduce must agree on random shapes
    # and device counts
    random.seed(41)
    for it in range(100):
      for n in range(2, 4+1):
        shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
        t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
        with Context(RING=0):
          a = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
        with Context(RING=2):
          b = Tensor(MultiLazyBuffer(all_reduce(ReduceOps.SUM, t.lazydata.lbs), 0))
        diff = a - b
        mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
        max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
        assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
        assert max_err < 1e-6, f"big max error, iteration {it}_{n}"
  def _test_matmul_shard_axis(self, shard_x, shard_w, device):
    # X@W with each operand sharded (or replicated) must match numpy matmul
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    Ws = W.shard(device, shard_w)
    O = (Xs@Ws)
    np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
  def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
    # two chained matmuls with both weights sharded the same way
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    W1s = W1.shard(device, shard_w)
    W2s = W2.shard(device, shard_w)
    O = (Xs@W1s)@W2s
    np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)
  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)
  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)
  def test_conv_data_shard(self):
    # data-parallel conv: replicated weights, batch sharded on axis 0
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()
  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()
  def test_backprop_conv(self):
    # full forward/backward/step on a data-parallel conv; just checks it runs
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      for p in get_parameters(conv): p.shard_(devices_2)
      optim = nn.optim.Adam(get_parameters(conv))
      fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
      out = conv(fake_image)
      optim.zero_grad()
      out.mean().backward()
      #for p in get_parameters(conv): p.grad.realize()
      optim.step()
      out.numpy()
  def test_lr_scheduler_OneCycleLR(self):
    # extra/ dependency; only checks a scheduler step works with sharded params
    from extra.lr_scheduler import OneCycleLR
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    optim = nn.optim.SGD(get_parameters(conv))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()
  def test_embedding(self):
    B, T, embed_size, vocab_size = 4, 10, 20, 28
    layer = nn.Embedding(vocab_size, embed_size)
    x = Tensor(np.random.randint(0, vocab_size, (B, T)))
    z = layer(x)
    # model-parallel embedding: weight sharded on the embedding dim, input replicated
    layer_sharded = nn.Embedding(vocab_size, embed_size)
    layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=1)).realize()
    x_sharded = x.shard(devices_2, axis=None)
    z_shard = layer_sharded(x_sharded)
    np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)
  def test_rmsnorm(self):
    B, T, embed_size = 4, 10, 20
    norm = nn.RMSNorm(embed_size)
    x = Tensor.rand((B, T, embed_size)).contiguous().realize()
    y = norm(x)
    # for norm layers, the correct way to shard weights is duplication
    norm_sharded = nn.RMSNorm(embed_size)
    norm_sharded.weight.shard_(devices_2, axis=None).realize()
    # if x is being sharded, then all-reduce is involved
    x_sharded = x.shard(devices_2, axis=2).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)
    # if x is being duplicated, then the operations remain inside each GPU
    # which is the common case
    x_sharded = x.shard(devices_2, axis=None).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)
  # NOTE: this is failing on LLVM CI, no idea why. Works locally.
  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow")
  def test_data_parallel_resnet(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18
    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    m = ResNet18()
    m.load_from_pretrained()
    real_output = m(fake_image).log_softmax().numpy()
    # replicate all params, then run on the batch-sharded input
    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).log_softmax().realize()
    assert shard_output.lazydata.lbs[0].shape == (1, 1000)
    assert shard_output.lazydata.lbs[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)
  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
  def test_data_parallel_resnet_train_step(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18
    from tinygrad.nn.optim import LARS
    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    labels = Tensor.randint(2, low=0, high=1000)
    labels_sharded = labels.shard(devices_2, axis=0)
    m = ResNet18()
    optimizer = LARS(get_parameters(m), 0.1) # set requires_grad for all params
    optimizer.zero_grad()
    m.load_from_pretrained()
    # single-device reference gradient
    output = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1)
    output.backward()
    grad = m.conv1.weight.grad.numpy()
    # same backward pass with params replicated and data sharded
    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    optimizer.zero_grad()
    shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
    assert shard_output.lazydata.axis is None
    shard_output.backward()
    shard_grad = m.conv1.weight.grad.numpy()
    # sometimes there is zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=3e-6, rtol=3e-6)
  def test_multi_tensor_jit_param(self):
    # sharded tensors passed as jit arguments
    @TinyJit
    def jf(a, b) -> Tensor:
      return (a + b).realize()
    for _ in range(5):
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      c = jf(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0
  def test_multi_tensor_jit_body(self):
    # sharded tensors created inside the jitted function body
    @TinyJit
    def jf() -> Tensor:
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      return (a + b).realize()
    for _ in range(5):
      r = jf()
      np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
    assert len(jf.jit_cache) > 0
  #@unittest.skipIf(CI and Device.DEFAULT=="METAL", "no ICB in CI, creation of graph fails")
  @unittest.skip("test broken")
  def test_multi_device_jit_graph(self):
    if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")
    @TinyJit
    def jf(a: Tensor, b: Tensor, c: Tensor, d:Tensor):
      # Create 80 entries on device 0: 2 batches.
      for _ in range(40):
        a = ((a + b).realize() + (a * b).realize()).realize()
      # Create 80 entries on device 1: 2 batches.
      for _ in range(40):
        c = ((c + d).realize() + (c * d).realize()).realize()
      # Create a copy from device 0 to 1: 1 entry.
      a = a.to(d1).realize()
      # Creates one last entry on device 1: 1 batch.
      return (a + c).realize()
    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d0).realize()
    c = Tensor.randn(10, 10, device=d1).realize()
    d = Tensor.randn(10, 10, device=d1).realize()
    ref = jf(a, b, c, d).numpy()
    for _ in range(5):
      o = jf(a, b, c, d).numpy()
      np.testing.assert_allclose(ref, o, atol=1e-4, rtol=1e-5)
    graph_d0 = Device[d0].graph.func if isinstance(Device[d0].graph, functools.partial) else Device[d0].graph
    graph_d1 = Device[d1].graph.func if isinstance(Device[d1].graph, functools.partial) else Device[d1].graph
    # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
    assert isinstance(jf.jit_cache[0].prg, graph_d0)
    assert isinstance(jf.jit_cache[1].prg, graph_d0)
    assert isinstance(jf.jit_cache[2].prg, graph_d1)
    assert isinstance(jf.jit_cache[3].prg, graph_d1)
    assert isinstance(jf.jit_cache[4].prg, BufferCopy)
    assert isinstance(jf.jit_cache[5].prg, graph_d1)
  def test_uneven_shard(self):
    # 257 does not divide evenly across 1..5 devices; movement ops must still work
    for N in range(1, 6):
      X = Tensor.rand(4, 1, 257).contiguous().realize()
      n = X.numpy()
      devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
      X.shard_(devices, 2)
      np.testing.assert_equal(X.numpy(), n)
      np.testing.assert_equal(X.reshape(2, 2, 257).numpy(), n.reshape((2, 2, 257)))
      np.testing.assert_equal(X.shrink(((0,2), (0, 1), (0,257))).numpy(), n[0:2, 0:1, 0:257])
      np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
      np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
  def test_uneven_multiple_zeros(self):
    # fewer elements than devices: some shards are zero-sized
    for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
      for N in (1, 2, 3, 4):
        devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
        # make sure something is computed on each device
        X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
        np.testing.assert_equal(X.numpy(), data)
  def test_bn_ast_on_devices(self):
    t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
    bn = nn.BatchNorm2d(64)
    for p in get_parameters(bn): p.shard_(devices_4).realize()
    out = bn(t)
    # non-copy kernels only, restricted to the shard devices
    scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not MetaOps.COPY]
    assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device"
    asts = [sched.ast for sched in scheds]
    assert len(asts)
    # test case to show that ast can be different on devices
    # TODO: make ast identical on devices
    #assert len(set(asts)) == 4, len(asts)
    # for i, ast in enumerate(asts):
    #   print(f"{i} {ast}")
  def test_reshape_on_axis(self):
    t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)
    # test split and rejoin to the right
    t1 = t0.reshape((26, 3, 5, 7))
    t2 = t0.reshape((26, 3, 35))
    t3 = t1.reshape((26, 15, 7))
    t4 = t2.reshape((26, 105,))
    for t in [t0, t1, t2, t3, t4]:
      assert t.lazydata.axis == 1
      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
    assert t5.lazydata.axis == 2
    # NOTE: `t` here is the last loop variable (t4) from the loop above
    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
    t6 = t0.reshape((13, 2, 3, 7, 5))
    t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
    assert t5.lazydata.axis == 2
    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
    assert t6.lazydata.axis == 2
    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())
    assert t7.lazydata.axis == 3
    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7))
  def test_reshape_on_axis_uneven(self):
    t0 = Tensor.rand((4, 8, 15)).shard(devices_3, axis=1)
    # no split axis if uneven
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((4,4,2,15))
    # ok to split reshape left and right though
    t1 = t0.reshape(2, 2, 8, 3, 5)
    np.testing.assert_allclose(t0.numpy().flatten(), t1.numpy().flatten())
    assert t1.lazydata.axis == 2
  def test_mlb_assign_change_axis(self):
    t_none = Tensor.zeros((16, 16)).shard(devices_2).contiguous().realize()
    t_zero = Tensor.ones((16, 16)).shard(devices_2, axis=0)
    with self.assertRaises(AssertionError):
      # don't allow assigns that change axes
      t_none.assign(t_zero)
  def test_dropout_on_shard(self):
    with Tensor.train():
      X = Tensor.ones(256).to(devices_2)
      output = X.dropout(0.5).numpy()
      # dropout(0.5) on ones scales survivors by 2; roughly half should be zero
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 100 < counts[0] < 156, counts[0]
  def test_broadcast_const(self):
    # a broadcast scalar must appear as a CONST in the ast, not as a buffer load
    for axis in (None, 0, 1):
      t = Tensor.zeros(16, 16).contiguous().shard(devices_4, axis).realize()
      t = t + 1
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is BufferOps.STORE
        assert ast.src[0].op is BinaryOps.ADD
        assert ast.src[0].src[0].op is BufferOps.LOAD and ast.src[0].src[0]
        assert ast.src[0].src[1].op is BufferOps.CONST and ast.src[0].src[1].arg.val == 1
      t = 2 * t
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is BufferOps.STORE
        assert ast.src[0].op is BinaryOps.MUL
        assert ast.src[0].src[0].op is BufferOps.CONST and ast.src[0].src[0].arg.val == 2
        assert ast.src[0].src[1].op is BufferOps.LOAD
      t = t + t.full_like(3)
      for si in t.schedule():
        ast = si.ast.src[0]
        assert ast.op is BufferOps.STORE
        assert ast.src[0].op is BinaryOps.ADD
        assert ast.src[0].src[0].op is BufferOps.LOAD
        assert ast.src[0].src[1].op is BufferOps.CONST and ast.src[0].src[1].arg.val == 3
  def test_shard_memory(self):
    # each shard is a base buffer of exactly its slice size (4 rows * 16 cols)
    devices = (d0, d1, d2, d3)
    t = Tensor.zeros(16, 16).contiguous()
    t.shard_(devices, axis=0)
    assert all([lb is lb.base and lb.buffer.base.size == 4 * 16 for lb in t.lazydata.lbs])
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestHandleData(unittest.TestCase):
  """Copying a sharded tensor to a single device: a copy to a device that holds
  no shard needs a kernel, a copy to a shard device should schedule nothing."""
  def test_copied_to_device(self):
    device = (d0, d1, d2, d3)
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    # d5 holds no shard, so gathering to it requires exactly one scheduled item
    not_covered = t.to(d5)
    sched = create_schedule([not_covered.lazydata])
    assert len(sched) == 1
    # setup again because create_schedule has side effect
    t = Tensor([1, 2, 3, 4]).shard(device).realize()
    not_covered = t.to(d5)
    assert not_covered.realize().tolist() == [1, 2, 3, 4]
    for d in device:
      # each shard device already has the (replicated) data: nothing to schedule
      t = Tensor([1, 2, 3, 4]).shard(device).realize()
      covered = t.to(d)
      sched = create_schedule([covered.lazydata])
      assert len(sched) == 0
      # setup again because create_schedule has side effect
      t = Tensor([1, 2, 3, 4]).shard(device).realize()
      covered = t.to(d)
      assert covered.realize().tolist() == [1, 2, 3, 4]
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
  # shrink a multitensor on sharded axis
  def test_shrink_bad_args(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
    with self.assertRaises(AssertionError):
      # sharded axis shrink on non-device boundry is not allowed
      a = t.shrink(((0, 3), (0, 8)))
    with self.assertRaises(AssertionError):
      # cannot shrink sharded and non-sharded axis at the same time
      a = t.shrink(((0, 2), (2, 4)))
    # a valid shrink keeps only the first shard "real"
    a = t.shrink(((0, 2), (0, 8)))
    assert a.shape == (2, 8)
    assert a.lazydata.real == [True, False, False, False]
    with self.assertRaises(AssertionError):
      # cannot pad sharded and non-sharded axis at the same time
      p = a.pad(((0, 6), (0, 1)))
    with self.assertRaises(AssertionError):
      # can only pad to whole axis
      p = a.pad(((1, 5), (0, 0)))
    # padding back to the full axis makes all shards real again
    p = a.pad(((0, 6), (0, 0)))
    assert p.shape == (8, 8)
    assert p.lazydata.real == [True, True, True, True]
  @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
  def test_ops(self, dtype):
    if not is_dtype_supported(dtype): return
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
    # compare every op on a single-shard shrink against the numpy-backed slice
    for i in range(4):
      print(f"{i=}")
      a = t.shrink(((0+2*i,2+2*i),None))
      b = Tensor(t.numpy()[0+2*i:2+2*i])
      assert a.shape == b.shape == (2, 8)
      assert a.lazydata.real == [i==j for j in range(4)]
      np.testing.assert_allclose(a.numpy(), b.numpy())
      # cast
      np.testing.assert_allclose(a.float().numpy(), b.float().numpy())
      # elementwise
      np.testing.assert_allclose(a.exp().numpy(), b.exp().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reciprocal().numpy(), b.reciprocal().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.pow(-0.5).numpy(), b.pow(-0.5).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a+a).numpy(), (b+b).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_equal((a+1).numpy(), (b+1).numpy())
      np.testing.assert_equal((1+a).numpy(), (1+b).numpy())
      np.testing.assert_allclose((a.where(a+a, a)).numpy(), (b.where(b+b, b)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a.where(1, 0)).numpy(), (b.where(1, 0)).numpy(), rtol=1e-7, atol=1e-3)
      # reduce
      np.testing.assert_allclose(a.max().numpy(), b.max().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum().numpy(), b.sum().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean().numpy(), b.mean().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(0).numpy(), b.max(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(0).numpy(), b.sum(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(0).numpy(), b.mean(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(1).numpy(), b.max(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(1).numpy(), b.sum(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(1).numpy(), b.mean(1).numpy(), rtol=1e-7, atol=1e-3)
      # pad it back
      np.testing.assert_allclose(a.pad(((2*i, 2*(4-i-1)), None)).numpy(), b.pad(((2*i, 2*(4-i-1)), None)).numpy(), rtol=1e-7, atol=1e-3)
      # other movement
      np.testing.assert_allclose(a.pad((None, (1, 1))).numpy(), b.pad((None, (1, 1))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.shrink((None, (1, 3))).numpy(), b.shrink((None, (1, 3))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.permute((1, 0)).numpy(), b.permute((1, 0)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 2, 4)).numpy(), b.reshape((2, 2, 4)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
  def test_uneven(self):
    # 3 rows on 2 devices: shard boundary is at row 2, second shard is smaller
    t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
    a = t.shrink(((0, 2), None))
    b = t.shrink(((2, 3), None))
    na = t.numpy()[0:2]
    nb = t.numpy()[2:3]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    np.testing.assert_equal((a+1).numpy(), na+1)
    np.testing.assert_equal((b+1).numpy(), nb+1)
    np.testing.assert_equal((1+a).numpy(), 1+na)
    np.testing.assert_equal((1+b).numpy(), 1+nb)
    np.testing.assert_equal((a+a).numpy(), na+na)
    np.testing.assert_equal((b+b).numpy(), nb+nb)
  def test_add_two_partitions(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
    a = t.shrink(((2, 4), None))
    b = t.shrink(((6, 8), None))
    na = t.numpy()[2:4]
    nb = t.numpy()[6:8]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    with self.assertRaises(AssertionError):
      # cannot add directly
      c = a + b
    # pad both back to the full axis first, then add
    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)
  def test_add_different_tensors(self):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(64).reshape(8, 8).contiguous().realize().shard(devices, axis=0)
    # one replicated (2, 8) tensor of value i per device
    to_add = []
    for i in range(len(devices)):
      to_add.append((Tensor.ones(2, 8) * i).shard(devices))
    # add each shard slice (via lazydata.bounds) to its tensor, then cat back
    added:List[Tensor] = []
    for bound, a in zip(x.lazydata.bounds, to_add):
      added.append(x[bound[0]:bound[1]] + a)
    output = added[0].cat(*added[1:])
    expected = np.arange(64).reshape((8,8)) + np.array([[0,0,1,1,2,2,3,3] for _ in range(8)]).T
    np.testing.assert_allclose(output.numpy(), expected)
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestBatchNorm(unittest.TestCase):
  """BatchNorm over sharded tensors: per-device ("unsynced") batchnorm layers,
  backprop through them, and synced vs unsynced scheduling."""

  def test_unsynced_backprop_conv_bn(self):
    """One conv+bn pair per device; each pair trains on its own half of the batch."""
    with Tensor.train():
      from extra.lr_scheduler import OneCycleLR
      convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
      bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16)]
      # replicate every parameter across both devices
      for p in get_parameters(convs + bns):
        p.shard_((d1, d2))
      optim = nn.optim.Adam(get_parameters(convs + bns))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()
      fake_image = Tensor.rand((8, 3, 32, 32)).shard((d1, d2), axis=0)
      # split the batch at the shard boundary so each bn normalizes only its half
      f1 = fake_image.shrink(((0, 4), None, None, None))
      f2 = fake_image.shrink(((4, 8), None, None, None))
      out1 = bns[0](convs[0](f1))
      out2 = bns[1](convs[1](f2))
      out = out1.cat(out2)
      optim.zero_grad()
      out.mean().backward()
      optim.step()
      # realize to force the whole backward/step graph to actually run
      out.numpy()

  def test_unsynced_backprop_standalone_bn(self):
    """Hand-rolled multi-device BatchNorm wrapper: one nn.BatchNorm2d per GPU,
    applied to that GPU's slice, results concatenated back."""
    from extra.lr_scheduler import OneCycleLR
    GPUS = (d1, d2)

    class BatchNorm:
      def __init__(self, num_features):
        # one independent BatchNorm2d per device; running stats disabled
        self.bns:List[nn.BatchNorm2d] = []
        for _ in GPUS:
          bn = nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
          self.bns.append(bn)

      def __call__(self, x:Tensor):
        # shrink x at each shard boundary, normalize per device, cat back together
        bn_ts = []
        for bound, bn in zip(x.lazydata.bounds, self.bns):
          xi = x.shrink((bound, None, None, None))
          bni = bn(xi)
          bn_ts.append(bni)
        return bn_ts[0].cat(*bn_ts[1:])

    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = BatchNorm(16)
      for p in get_parameters([conv, bn]):
        p.shard_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()
      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  def test_unsynced_backprop_sync_weights(self):
    """UnsyncedBatchNorm from hlb_cifar10: running stats are sharded per device
    (axis=0) while the learnable weights are replicated across devices."""
    from extra.lr_scheduler import OneCycleLR
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    GPUS = (d1, d2)
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
      for k, p in get_state_dict([conv, bn]).items():
        if 'running_mean' in k or 'running_var' in k:
          # per-device statistics: shard along the device axis
          p.shard_(GPUS, axis=0)
        else:
          # everything else is replicated
          p.to_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()
      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  @given(strat.sampled_from((False, True)))
  def test_batchnorm(self, is_training):
    """Per-shard BatchNorm2d realizes cleanly in both training and eval mode."""
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(4096).reshape(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
    with Tensor.train(is_training):
      # one BatchNorm2d per device, parameters replicated and trainable
      bns = []
      for _ in range(len(devices)):
        bn = nn.BatchNorm2d(8)
        for p in get_parameters(bn):
          p.shard_(devices)
        bn.weight.requires_grad = True
        bn.bias.requires_grad = True
        bns.append(bn)
      # apply each bn to its device's slice, then realize the concatenation
      bn_ts = []
      for bound, bn in zip(x.lazydata.bounds, bns):
        bni = bn(x[bound[0]:bound[1]])
        bn_ts.append(bni)
      bn_ts[0].cat(*bn_ts[1:]).numpy()

  def test_synced_vs_unsynced_bn(self):
    """Both synced (nn.BatchNorm2d) and unsynced batchnorm produce a non-empty
    schedule for a sharded input."""
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    from tinygrad.nn import BatchNorm2d
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
    with Tensor.train():
      synced_bn = BatchNorm2d(8)
      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
      for p in get_parameters(synced_bn):
        p.shard_(devices)
      for k, p in get_state_dict(unsynced_bn).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(devices, axis=0)
        else:
          p.to_(devices)
      synced_out = synced_bn(x)
      synced_si = list(create_schedule(synced_out.lazydata.lbs))
      unsynced_out = unsynced_bn(x)
      unsynced_si = list(create_schedule(unsynced_out.lazydata.lbs))
    # TODO: test synced / unsynced batchnorm cross device kernel and copies
    assert synced_si
    assert unsynced_si
# run the whole suite when invoked directly
if __name__ == '__main__':
  unittest.main()