# test_nn.py
#!/usr/bin/env python
import unittest

import numpy as np
import torch

from tinygrad import Tensor, Device, TinyJit
from tinygrad.helpers import CI, Context
from tinygrad.ops import MetaOps
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
  13. @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
  14. class TestNN(unittest.TestCase):
  15. @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU")
  16. def test_sparse_cat_cross_entropy(self):
  17. # create in tinygrad
  18. input_tensor = Tensor.randn(5, 5)
  19. target = Tensor([0, 0, 0, 1, 2]) # torch doesn't support target=-1
  20. torch_input = torch.tensor(input_tensor.numpy())
  21. torch_target = torch.tensor(target.numpy(), dtype=torch.long)
  22. for smoothing in [0.0, 0.1, 0.5, 1.0]:
  23. for ignore_index in [-1, 0, 2]:
  24. loss = input_tensor.sparse_categorical_crossentropy(target, label_smoothing=smoothing, ignore_index=ignore_index)
  25. torch_loss = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=smoothing, ignore_index=ignore_index)(torch_input, torch_target)
  26. np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
  def test_batchnorm2d(self, training=False, threed=False):
    """Compare tinygrad BatchNorm against torch BatchNorm2d/BatchNorm3d.

    Checks the normalized output and the running statistics both before and
    after a forward pass, over several channel sizes, in eval or train mode.
    """
    with Tensor.train(training):
      szs = [4, 8, 16, 32]
      for sz in szs:
        # create in tinygrad
        bn = BatchNorm(sz, eps=1e-5, track_running_stats=training)
        bn.weight = Tensor.randn(sz)
        bn.bias = Tensor.randn(sz)
        bn.running_mean = Tensor.randn(sz)
        # a variance must be non-negative, so clamp negative draws to zero.
        # FIX: the previous `bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0`
        # mutated a temporary numpy copy and had no effect on the tensor.
        bn.running_var = Tensor.randn(sz).relu()
        # create in torch, copying the tinygrad parameters over
        with torch.no_grad():
          if threed:
            tbn = torch.nn.BatchNorm3d(sz).eval()
          else:
            tbn = torch.nn.BatchNorm2d(sz).eval()
          tbn.training = training
          tbn.weight[:] = torch.tensor(bn.weight.numpy())
          tbn.bias[:] = torch.tensor(bn.bias.numpy())
          tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
          tbn.running_var[:] = torch.tensor(bn.running_var.numpy())
        # running stats must agree before any forward pass
        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
        # trial
        if threed:
          inn = Tensor.randn(2, sz, 3, 3, 3)
        else:
          inn = Tensor.randn(2, sz, 3, 3)
        # in tinygrad
        outt = bn(inn)
        # in torch
        toutt = tbn(torch.tensor(inn.numpy()))
        # outputs and (possibly updated) running stats must match
        np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6)
        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
  def test_batchnorm2d_training(self):
    # same comparison with Tensor.train enabled (running stats get updated)
    self.test_batchnorm2d(True)
  def test_batchnorm3d(self): self.test_batchnorm2d(False, True)  # 5D input / BatchNorm3d path
  def test_batchnorm3d_training(self): self.test_batchnorm2d(True, True)  # 5D path in train mode
  def test_batchnorm_axis(self):
    # batchnorm over a non-default axis set must equal batchnorm applied to
    # the same data permuted/reshaped into the default layout
    sz = (2, 4, 3, 2, 2)
    x = Tensor.randn(sz)
    weight = Tensor.randn(2, 3)
    bias = Tensor.randn(2, 3)
    mean = Tensor.randn(2, 3)
    invstd = Tensor.randn(2, 3)
    # a: normalize over axes (0, 2), then bring into the flat layout
    a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2))
         .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2))
    # b: flatten the layout first, then normalize over the default axis
    b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
         .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten()))
    t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy())
    t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy())
    t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy())
    # NOTE(review): this torch batch_norm result is computed but never asserted
    # against `a` or `b` — either dead code or a missing comparison; confirm intent
    torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias)
    np.testing.assert_allclose(a.numpy(), b.numpy())
  def test_linear(self):
    """tinygrad Linear matches torch.nn.Linear for 2D and higher-rank inputs."""
    def _check(x, in_dim, out_dim):
      # tinygrad layer and its output
      model = Linear(in_dim, out_dim)
      z = model(x)
      # torch layer seeded with the same weight/bias
      with torch.no_grad():
        torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
        torch_layer.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32)
        torch_layer.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32)
        torch_z = torch_layer(torch.tensor(x.numpy(), dtype=torch.float32))
      # outputs must agree
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
    BS, T, in_dim, out_dim = 4, 2, 8, 16
    _check(Tensor.randn(BS, in_dim), in_dim, out_dim)
    _check(Tensor.randn(BS, T, in_dim), in_dim, out_dim)  # test with more dims
  def test_conv1d(self):
    """tinygrad Conv1d output matches torch.nn.Conv1d with identical parameters."""
    BS, C1, W = 4, 16, 224//4
    C2, K, S, P = 64, 7, 2, 1
    # tinygrad layer
    layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
    # torch layer seeded with the same weight/bias
    with torch.no_grad():
      torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # same input through both implementations
    x = Tensor.uniform(BS, C1, W)
    np.testing.assert_allclose(layer(x).numpy(),
                               torch_layer(torch.tensor(x.numpy())).detach().numpy(),
                               atol=5e-4, rtol=1e-5)
  def test_conv2d(self):
    """tinygrad Conv2d output matches torch.nn.Conv2d with identical parameters."""
    BS, C1, H, W = 4, 16, 224//4, 224//4
    C2, K, S, P = 64, 7, 2, 1
    # tinygrad layer
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    # torch layer seeded with the same weight/bias
    with torch.no_grad():
      torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # same input through both implementations
    x = Tensor.uniform(BS, C1, H, W)
    np.testing.assert_allclose(layer(x).numpy(),
                               torch_layer(torch.tensor(x.numpy())).detach().numpy(),
                               atol=5e-4, rtol=1e-5)
  @unittest.skip("Takes too long to compile for Compiled backends")
  def test_conv2d_winograd(self):
    # Conv2d forward AND backward under WINO=1 must match torch's convolution.
    BS, C1, H, W = 2, 8, 16, 16
    C2, K, S, P = 8, 3, 1, 1
    # create in tinygrad; grads are needed for the backward checks below
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    layer.weight.requires_grad = True
    layer.bias.requires_grad = True
    # create in torch with the same parameters
    torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
    torch_layer.weight = torch.nn.Parameter(torch.tensor(layer.weight.numpy(), dtype=torch.float32))
    torch_layer.bias = torch.nn.Parameter(torch.tensor(layer.bias.numpy(), dtype=torch.float32))
    # test
    x = Tensor.uniform(BS, C1, H, W, requires_grad=True)
    # WINO=1 selects the winograd conv path while this context is active
    with Context(WINO=1):
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
      # backward through the mean, realizing each tinygrad gradient
      m = z.mean()
      m.backward()
      gw = layer.weight.grad.realize()
      gb = layer.bias.grad.realize()
      gx = x.grad.realize()
      torch_z.mean().backward()
      np.testing.assert_allclose(gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
      np.testing.assert_allclose(gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5)
      np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)
  @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI")
  def test_conv_transpose1d(self):
    """tinygrad ConvTranspose1d matches torch.nn.ConvTranspose1d with identical parameters."""
    BS, C1, W = 4, 16, 224//4
    C2, K, S, P = 64, 7, 2, 1
    # tinygrad layer
    layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)
    # torch layer seeded with the same weight/bias
    with torch.no_grad():
      torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # same input through both implementations
    x = Tensor.uniform(BS, C1, W)
    np.testing.assert_allclose(layer(x).numpy(),
                               torch_layer(torch.tensor(x.numpy())).detach().numpy(),
                               atol=5e-4, rtol=1e-5)
  @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI")
  def test_conv_transpose2d(self):
    """tinygrad ConvTranspose2d matches torch.nn.ConvTranspose2d with identical parameters."""
    BS, C1, H, W = 4, 16, 224//4, 224//4
    C2, K, S, P = 64, 7, 2, 1
    # tinygrad layer
    layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P)
    # torch layer seeded with the same weight/bias
    with torch.no_grad():
      torch_layer = torch.nn.ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # same input through both implementations
    x = Tensor.uniform(BS, C1, H, W)
    np.testing.assert_allclose(layer(x).numpy(),
                               torch_layer(torch.tensor(x.numpy())).detach().numpy(),
                               atol=5e-4, rtol=1e-5)
  def test_groupnorm(self):
    """GroupNorm forward and backward agree with torch.nn.GroupNorm."""
    BS, H, W, C, G = 20, 10, 10, 6, 3
    # torch reference layer; copy its affine params into the tinygrad layer
    torch_layer = torch.nn.GroupNorm(G, C).eval()
    layer = GroupNorm(G, C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
    for _ in range(10):
      # forward
      x = Tensor.randn(BS, C, H, W, requires_grad=True)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      z, torch_z = layer(x), torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      for got, want in ((x.grad, torch_x.grad),
                        (layer.weight.grad, torch_layer.weight.grad),
                        (layer.bias.grad, torch_layer.bias.grad)):
        np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=5e-4)
  def test_layernorm(self):
    """LayerNorm over the trailing [H, W] dims agrees with torch, forward and backward."""
    N, C, H, W = 20, 5, 10, 10
    # torch reference layer; copy its affine params into the tinygrad layer
    torch_layer = torch.nn.LayerNorm([H, W]).eval()
    layer = LayerNorm([H, W])
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, H, W, requires_grad=True)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      z, torch_z = layer(x), torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      for got, want in ((x.grad, torch_x.grad),
                        (layer.weight.grad, torch_layer.weight.grad),
                        (layer.bias.grad, torch_layer.bias.grad)):
        np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=5e-4)
  def test_layernorm_2d(self):
    """LayerNorm2d (channels-first) agrees with torch LayerNorm applied channels-last."""
    N, C, H, W = 20, 5, 10, 10
    # torch reference layer; copy its affine params into the tinygrad layer
    torch_layer = torch.nn.LayerNorm([C]).eval()
    layer = LayerNorm2d(C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
    for _ in range(10):
      # forward: torch normalizes over the last axis, so permute NCHW->NHWC and back
      x = Tensor.randn(N, C, H, W, requires_grad=True)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      z = layer(x)
      torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      for got, want in ((x.grad, torch_x.grad),
                        (layer.weight.grad, torch_layer.weight.grad),
                        (layer.bias.grad, torch_layer.bias.grad)):
        np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=5e-4, rtol=5e-4)
  def test_instancenorm_2d(self):
    """InstanceNorm on 4D input agrees with torch.nn.InstanceNorm2d, forward and backward."""
    N, C, H, W = 20, 10, 10, 10
    # torch reference layer; copy its affine params into the tinygrad layer
    torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
    layer = InstanceNorm(C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, H, W, requires_grad=True)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      z, torch_z = layer(x), torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      for got, want in ((x.grad, torch_x.grad),
                        (layer.weight.grad, torch_layer.weight.grad),
                        (layer.bias.grad, torch_layer.bias.grad)):
        np.testing.assert_allclose(got.numpy(), want.detach().numpy(), atol=1e-3, rtol=1e-3)
  def test_instancenorm_3d(self):
    """InstanceNorm on 5D input agrees with torch.nn.InstanceNorm3d, forward and backward."""
    N, C, D, H, W = 20, 10, 10, 10, 10
    # torch reference layer; copy its affine params into the tinygrad layer
    torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
    layer = InstanceNorm(C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, D, H, W, requires_grad=True)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      z, torch_z = layer(x), torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
      # backward — the weight gradient gets a slightly looser atol
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
  def test_rmsnorm(self):
    # reference implementation, taken from llama:
    # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
    class TorchRMSNorm(torch.nn.Module):
      def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

      def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last axis
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

      def forward(self, x):
        # normalize in float32, cast back to the input dtype, then scale
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    B, T, embed_size = 4, 10, 20
    torch_layer = TorchRMSNorm(embed_size)
    layer = RMSNorm(embed_size)
    layer.weight.requires_grad = True  # needed for the weight-grad check below
    for _ in range(10):
      # forward
      x = Tensor.randn(B, T, embed_size, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
  def test_embedding(self):
    # Embedding lookup matches torch.nn.Embedding, including the empty-sequence
    # case and when wrapped in TinyJit.
    B, T, embed_size, vocab_size = 4, 10, 20, 28
    # create in tinygrad
    layer = Embedding(vocab_size, embed_size)
    # create in torch with the same weight table
    with torch.no_grad():
      torch_layer = torch.nn.Embedding(vocab_size, embed_size).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
    # test
    x = Tensor(np.random.randint(0, vocab_size, (B, T)))
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
    # test with empty input length
    x = Tensor(np.random.randint(0, vocab_size, (B, 0)))
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
    # test with jit enabled
    @TinyJit
    def layer_jit(x):
      return layer(x).realize()
    # call several times so the jit both captures and replays the lookup
    for _ in range(3):
      x = Tensor(np.random.randint(0, vocab_size, (B, T)))
      z = layer_jit(x)
      torch_x = torch.tensor(x.numpy())
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
  def test_embedding_one_kernel(self):
    # once the weight and arange buffers are realized, each embedding lookup
    # should schedule exactly one compute kernel
    layer = Embedding(20, 30)
    a = Tensor([[1, 5, 9, 11],
                [12, 19, 8, 1]])
    result = layer(a)
    schedule = create_schedule([result.lazydata])
    self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "first run realizes arange, weight, and embedding")
    run_schedule(schedule)
    # second lookup reuses the realized buffers
    b = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    result = layer(b)
    schedule = create_schedule([result.lazydata])
    self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "second run realizes embedding only")
    run_schedule(schedule)
  def test_load_state_dict(self):
    """load_state_dict copies the provided tensors into the layer's parameters."""
    layer = Conv2d(3, 5, kernel_size=3)
    state_dict = {
      'weight': Tensor.randn(5, 3, 3, 3),
      'bias': Tensor.randn(5),
    }
    load_state_dict(layer, state_dict)
    # every entry must now be reflected in the layer
    for name in ('weight', 'bias'):
      np.testing.assert_allclose(getattr(layer, name).numpy(), state_dict[name].numpy())
  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
  def test_load_state_dict_sharded(self):
    """load_state_dict preserves the multi-device sharding of the target layer."""
    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
    layer = Conv2d(3, 5, kernel_size=3)
    layer.weight.shard_(devices, -1)   # weight split along the last axis
    layer.bias.shard_(devices, None)   # bias replicated across devices
    state_dict = {
      'weight': Tensor.randn(5, 3, 3, 3).shard(devices, -1),
      'bias': Tensor.randn(5).shard(devices, None),
    }
    load_state_dict(layer, state_dict)
    # parameters stay on both devices and carry the loaded values
    for name in ('weight', 'bias'):
      param = getattr(layer, name)
      self.assertEqual(param.device, devices)
      np.testing.assert_allclose(param.numpy(), state_dict[name].numpy())
# run the test suite when executed directly
if __name__ == '__main__':
  unittest.main()