# test_disk_tensor.py — tinygrad disk-tensor unit tests
  1. import pathlib, tempfile, unittest
  2. import numpy as np
  3. from tinygrad import Tensor, Device, dtypes
  4. from tinygrad.dtype import DType
  5. from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
  6. from tinygrad.helpers import Timing, fetch, temp, CI
  7. from test.helpers import is_dtype_supported
  8. def compare_weights_both(url):
  9. import torch
  10. fn = fetch(url)
  11. tg_weights = get_state_dict(torch_load(fn))
  12. torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu')), tensor_type=torch.Tensor)
  13. assert list(tg_weights.keys()) == list(torch_weights.keys())
  14. for k in tg_weights:
  15. if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
  16. if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
  17. if torch_weights[k].requires_grad: torch_weights[k] = torch_weights[k].detach()
  18. np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
  19. print(f"compared {len(tg_weights)} weights")
  20. class TestTorchLoad(unittest.TestCase):
  21. # pytorch pkl format
  22. def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
  23. # pytorch zip format
  24. def test_load_enet_alt(self): compare_weights_both("https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth")
  25. # pytorch zip format
  26. def test_load_convnext(self): compare_weights_both('https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth')
  27. @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
  28. def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
  29. # pytorch tar format
  30. def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
# Local LLaMA checkpoint used by the read-speed benchmark below; that test is skipped when the file is absent.
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
#test_size = test_fn.stat().st_size
# Read a fixed 2 GiB rather than the whole file so the benchmark is comparable across checkpoints.
test_size = 1024*1024*1024*2
  34. def _test_bitcasted(t: Tensor, dt: DType, expected):
  35. np.testing.assert_allclose(t.bitcast(dt).numpy(), expected)
  36. # sudo su -c 'sync; echo 1 > /proc/sys/vm/drop_caches' && python3 test/unit/test_disk_tensor.py TestRawDiskBuffer.test_readinto_read_speed
  37. class TestRawDiskBuffer(unittest.TestCase):
  38. @unittest.skipIf(not test_fn.exists(), "download LLaMA weights for read in speed tests")
  39. def test_readinto_read_speed(self):
  40. tst = np.empty(test_size, np.uint8)
  41. with open(test_fn, "rb") as f:
  42. with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
  43. f.readinto(tst)
  44. def test_bitcasts_on_disk(self):
  45. _, tmp = tempfile.mkstemp()
  46. # ground truth = https://evanw.github.io/float-toy/
  47. t = Tensor.empty((128, 128), dtype=dtypes.uint8, device=f"disk:{tmp}") # uint8
  48. # all zeroes
  49. _test_bitcasted(t, dtypes.float16, 0.0)
  50. _test_bitcasted(t, dtypes.uint16, 0)
  51. _test_bitcasted(t, dtypes.float32, 0.0)
  52. _test_bitcasted(t, dtypes.uint32, 0)
  53. # pi in float16 stored via int16
  54. t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
  55. _test_bitcasted(t, dtypes.float16, 3.140625)
  56. _test_bitcasted(t, dtypes.float32, 50.064727)
  57. _test_bitcasted(t, dtypes.uint16, 0x4248)
  58. _test_bitcasted(t, dtypes.uint32, 0x42484248)
  59. # pi in float32 stored via float32
  60. t.bitcast(dtypes.float32).assign(Tensor.full((128, 32), 3.1415927, dtype=dtypes.float32)).realize()
  61. _test_bitcasted(t, dtypes.float32, 3.1415927)
  62. _test_bitcasted(t, dtypes.uint32, 0x40490FDB)
  63. # doesn't suport normal cast
  64. with self.assertRaises(RuntimeError):
  65. Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16)
  66. # Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
  67. with self.assertRaises(RuntimeError):
  68. # should fail because 3 int8 is 3 bytes but float16 is two and 3 isn't a multiple of 2
  69. Tensor.empty((3,), dtype=dtypes.int8, device=f"DISK:{tmp}").bitcast(dtypes.float16)
  70. with self.assertRaises(RuntimeError):
  71. # should fail because backprop through bitcast is undefined
  72. Tensor.empty((4,), dtype=dtypes.int8, requires_grad=True, device=f"DISK:{tmp}").bitcast(dtypes.float16)
  73. pathlib.Path(tmp).unlink()
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):
  # Round-trips between tinygrad's safe_save/safe_load and the reference
  # huggingface `safetensors` implementation, including byte-identical output.
  def test_real_safetensors(self):
    import torch
    from safetensors.torch import save_file
    torch.manual_seed(1337)
    tensors = {
      "weight1": torch.randn((16, 16)),
      "weight2": torch.arange(0, 17, dtype=torch.uint8),
      "weight3": torch.arange(0, 17, dtype=torch.int32).reshape(17,1,1),
      "weight4": torch.arange(0, 2, dtype=torch.uint8),
    }
    save_file(tensors, temp("real.safetensors"))
    ret = safe_load(temp("real.safetensors"))
    for k,v in tensors.items(): np.testing.assert_array_equal(ret[k].numpy(), v.numpy())
    safe_save(ret, temp("real.safetensors_alt"))
    # tinygrad's writer must reproduce the reference file byte-for-byte
    with open(temp("real.safetensors"), "rb") as f:
      with open(temp("real.safetensors_alt"), "rb") as g:
        assert f.read() == g.read()
    ret2 = safe_load(temp("real.safetensors_alt"))
    for k,v in tensors.items(): np.testing.assert_array_equal(ret2[k].numpy(), v.numpy())
  def test_real_safetensors_open(self):
    # saved file must be readable by the reference safetensors.safe_open
    fn = temp("real_safe")
    state_dict = {"tmp": Tensor.rand(10,10)}
    safe_save(state_dict, fn)
    import os
    # 8-byte header-length prefix + 0x40-byte JSON header + 10*10 float32 payload
    assert os.path.getsize(fn) == 8+0x40+(10*10*4)
    from safetensors import safe_open
    with safe_open(fn, framework="pt", device="cpu") as f:
      assert sorted(f.keys()) == sorted(state_dict.keys())
      for k in f.keys():
        np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
  def test_efficientnet_safetensors(self):
    # round-trip a full model state dict through safe_save/safe_load
    from extra.models.efficientnet import EfficientNet
    model = EfficientNet(0)
    state_dict = get_state_dict(model)
    safe_save(state_dict, temp("eff0"))
    state_dict_loaded = safe_load(temp("eff0"))
    assert sorted(state_dict_loaded.keys()) == sorted(state_dict.keys())
    for k,v in state_dict.items():
      np.testing.assert_array_equal(v.numpy(), state_dict_loaded[k].numpy())
    # load with the real safetensors
    from safetensors import safe_open
    with safe_open(temp("eff0"), framework="pt", device="cpu") as f:
      assert sorted(f.keys()) == sorted(state_dict.keys())
      for k in f.keys():
        np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
  def test_huggingface_enet_safetensors(self):
    # test a real file
    fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
    state_dict = safe_load(fn)
    assert len(state_dict.keys()) == 244
    assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
    assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
    assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570
  def test_metadata(self):
    # __metadata__ passed to safe_save must land in the JSON header
    metadata = {"hello": "world"}
    safe_save({}, temp('metadata.safetensors'), metadata)
    import struct
    with open(temp('metadata.safetensors'), 'rb') as f:
      dat = f.read()
    # first 8 bytes: big-endian u64 length of the JSON header
    sz = struct.unpack(">Q", dat[0:8])[0]
    import json
    assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
  def test_save_all_dtypes(self):
    for dtype in dtypes.fields().values():
      if dtype in [dtypes.bfloat16]: continue # not supported in numpy
      path = temp(f"ones.{dtype}.safetensors")
      ones = Tensor(np.random.rand(10,10), dtype=dtype)
      safe_save(get_state_dict(ones), path)
      np.testing.assert_equal(ones.numpy(), list(safe_load(path).values())[0].numpy())
  def test_load_supported_types(self):
    import torch
    from safetensors.torch import save_file
    from safetensors.numpy import save_file as np_save_file
    torch.manual_seed(1337)
    tensors = {
      "weight_F16": torch.randn((2, 2), dtype=torch.float16),
      "weight_F32": torch.randn((2, 2), dtype=torch.float32),
      "weight_U8": torch.tensor([1, 2, 3], dtype=torch.uint8),
      "weight_I8": torch.tensor([-1, 2, 3], dtype=torch.int8),
      "weight_I32": torch.tensor([-1, 2, 3], dtype=torch.int32),
      "weight_I64": torch.tensor([-1, 2, 3], dtype=torch.int64),
      "weight_F64": torch.randn((2, 2), dtype=torch.double),
      "weight_BOOL": torch.tensor([True, False], dtype=torch.bool),
      "weight_I16": torch.tensor([127, 64], dtype=torch.short),
      "weight_BF16": torch.randn((2, 2), dtype=torch.bfloat16),
    }
    save_file(tensors, temp("dtypes.safetensors"))
    loaded = safe_load(temp("dtypes.safetensors"))
    for k,v in loaded.items():
      # bf16 is skipped here because numpy cannot represent it for comparison
      if v.dtype != dtypes.bfloat16:
        assert v.numpy().dtype == tensors[k].numpy().dtype
        np.testing.assert_allclose(v.numpy(), tensors[k].numpy())
    # pytorch does not support U16, U32, and U64 dtypes.
    tensors = {
      "weight_U16": np.array([1, 2, 3], dtype=np.uint16),
      "weight_U32": np.array([1, 2, 3], dtype=np.uint32),
      "weight_U64": np.array([1, 2, 3], dtype=np.uint64),
    }
    np_save_file(tensors, temp("dtypes.safetensors"))
    loaded = safe_load(temp("dtypes.safetensors"))
    for k,v in loaded.items():
      assert v.numpy().dtype == tensors[k].dtype
      np.testing.assert_allclose(v.numpy(), tensors[k])
  179. def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
  180. if tinygrad_fxn is None: tinygrad_fxn = np_fxn
  181. pathlib.Path(temp(fn)).unlink(missing_ok=True)
  182. tinygrad_tensor = Tensor(data, device="CLANG").to(f"disk:{temp(fn)}")
  183. numpy_arr = np.array(data)
  184. tinygrad_fxn(tinygrad_tensor)
  185. np_fxn(numpy_arr)
  186. np.testing.assert_allclose(tinygrad_tensor.numpy(), numpy_arr)
class TestDiskTensor(unittest.TestCase):
  # End-to-end behavior of `disk:` device tensors: creation, reads, writes,
  # slicing, bitcasts, and copies to the default device.
  def test_empty(self):
    pathlib.Path(temp("dt_empty")).unlink(missing_ok=True)
    Tensor.empty(100, 100, device=f"disk:{temp('dt_empty')}")
  def test_simple_read(self):
    fn = pathlib.Path(temp("dt_simple_read"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256)))
    t = Tensor.empty(16, 16, device=f"disk:{temp('dt_simple_read')}", dtype=dtypes.uint8)
    # row 1 of the 16x16 view covers file bytes 16..31
    out = t[1].to(Device.DEFAULT).tolist()
    assert out == list(range(16, 32))
  def test_simple_read_bitcast(self):
    fn = pathlib.Path(temp("dt_simple_read_bitcast"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*2)
    t = Tensor.empty(16, 16*2, device=f"disk:{temp('dt_simple_read_bitcast')}", dtype=dtypes.uint8)
    # slice first, then bitcast uint8 pairs to little-endian uint16
    out = t[1].bitcast(dtypes.uint16).to(Device.DEFAULT).tolist()
    tout = [(x//256, x%256) for x in out]  # split each uint16 back into (high, low) bytes
    assert tout == list([(x+1,x) for x in range(32,64,2)])
  def test_simple_read_bitcast_alt(self):
    fn = pathlib.Path(temp("dt_simple_read_bitcast_alt"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*2)
    t = Tensor.empty(16, 16*2, device=f"disk:{temp('dt_simple_read_bitcast_alt')}", dtype=dtypes.uint8)
    # same as above but bitcast first, then slice — result must match
    out = t.bitcast(dtypes.uint16)[1].to(Device.DEFAULT).tolist()
    tout = [(x//256, x%256) for x in out]
    assert tout == list([(x+1,x) for x in range(32,64,2)])
  def test_write_ones(self):
    pathlib.Path(temp("dt_write_ones")).unlink(missing_ok=True)
    out = Tensor.ones(10, 10, device="CLANG").contiguous()
    outdisk = out.to(f"disk:{temp('dt_write_ones')}")
    print(outdisk)
    outdisk.realize()
    del out, outdisk
    import struct
    # test file
    with open(temp("dt_write_ones"), "rb") as f:
      # 100 little-endian float32 ones on disk
      assert f.read() == struct.pack('<f', 1.0) * 100 == b"\x00\x00\x80\x3F" * 100
    # test load alt
    reloaded = Tensor.empty(10, 10, device=f"disk:{temp('dt_write_ones')}")
    np.testing.assert_almost_equal(reloaded.numpy(), np.ones((10, 10)))
  def test_assign_slice(self):
    def assign(x,s,y): x[s] = y
    helper_test_disk_tensor("dt_assign_slice_1", [0,1,2,3], lambda x: assign(x, slice(0,2), [13, 12]))
    helper_test_disk_tensor("dt_assign_slice_2", [[0,1,2,3],[4,5,6,7]], lambda x: assign(x, slice(0,1), [[13, 12, 11, 10]]))
  def test_reshape(self):
    helper_test_disk_tensor("dt_reshape_1", [1,2,3,4,5], lambda x: x.reshape((1,5)))
    helper_test_disk_tensor("dt_reshape_2", [1,2,3,4], lambda x: x.reshape((2,2)))
  def test_assign_to_different_dtype(self):
    # NOTE: this is similar to Y_train in fetch_cifar
    t = Tensor.empty(10, device=f'disk:{temp("dt_assign_to_different_dtype")}', dtype=dtypes.int64)
    for i in range(5):
      data = np.array([3, 3])
      idx = 2 * i
      t[idx:idx+2].assign(data)
    np.testing.assert_array_equal(t.numpy(), np.array([3] * 10))
  def test_bitcast(self):
    with open(temp('dt_bitcast'), "wb") as f: f.write(bytes(range(10,20)))
    t = Tensor.empty(5, dtype=dtypes.int16, device=f"disk:{temp('dt_bitcast')}")
    ret = t.to("CLANG").bitcast(dtypes.uint16) + 1
    # e.g. bytes (10, 11) -> little-endian 0x0B0A = 2826, +1 = 2827
    assert ret.tolist() == [2827, 3341, 3855, 4369, 4883]
  def test_bitcast_view(self):
    with open(temp('dt_bitcast_view'), "wb") as f: f.write(bytes(range(10, 24)))
    # bitcast of a shrunken (non-full) view of a disk tensor
    t = Tensor.empty(3, dtype=dtypes.uint, device=f"disk:{temp('dt_bitcast_view')}").shrink([(0, 2)])
    ret = t.bitcast(dtypes.uint16).to("CLANG") + 1
    assert ret.tolist() == [2827, 3341, 3855, 4369]
  def test_bf16_disk_write_read(self):
    t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
    t.to(f"disk:{temp('dt_bf16_disk_write_read_f32')}").realize()
    # hack to "cast" f32 -> bf16: keep only the top 2 bytes of each little-endian float32
    with open(temp('dt_bf16_disk_write_read_f32'), "rb") as f: dat = f.read()
    adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
    with open(temp('dt_bf16_disk_write_read_bf16'), "wb") as f: f.write(adat)
    t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('dt_bf16_disk_write_read_bf16')}")
    ct = t.llvm_bf16_cast(dtypes.float)
    # truncating the mantissa rounds 10000 -> 9984
    assert ct.numpy().tolist() == [9984., -1, -1000, -9984, 20]
  def test_copy_from_disk(self):
    fn = pathlib.Path(temp("dt_copy_from_disk"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*1024)
    t = Tensor.empty(256*1024, device=f"disk:{temp('dt_copy_from_disk')}", dtype=dtypes.uint8)
    on_dev = t.to(Device.DEFAULT).realize()
    np.testing.assert_equal(on_dev.numpy(), t.numpy())
  def test_copy_from_disk_offset(self):
    fn = pathlib.Path(temp("dt_copy_from_disk_offset"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*1024)
    # copies that start at a non-zero (including non-page-aligned) byte offset
    for off in [314, 991, 2048, 4096]:
      t = Tensor.empty(256*1024, device=f"disk:{temp('dt_copy_from_disk_offset')}", dtype=dtypes.uint8)[off:]
      on_dev = t.to(Device.DEFAULT).realize()
      np.testing.assert_equal(on_dev.numpy(), t.numpy())
  def test_copy_from_disk_huge(self):
    if CI and not hasattr(Device["DISK"], 'io_uring'): self.skipTest("slow on ci without iouring")
    fn = pathlib.Path(temp("dt_copy_from_disk_huge"))
    fn.unlink(missing_ok=True)
    fn.write_bytes(bytes(range(256))*1024*256)  # 64 MiB
    for off in [0, 551]:
      t = Tensor.empty(256*1024*256, device=f"disk:{temp('dt_copy_from_disk_huge')}", dtype=dtypes.uint8)[off:]
      on_dev = t.to(Device.DEFAULT).realize()
      np.testing.assert_equal(on_dev.numpy(), t.numpy())
# standard unittest entry point
if __name__ == "__main__":
  unittest.main()