import os, json, pathlib, zipfile, pickle, tarfile, struct
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
from tinygrad.shape.view import strides_for_shape
from tinygrad.multi import MultiLazyBuffer

safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
               "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
  """
  Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
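
  A minimal sketch, following the examples below (the filename is illustrative):

  ```python
  data, json_len, metadata = nn.state.safe_load_metadata("test.safetensor")
  ```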
  14. """
  15. t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  16. json_len = t[0:8].bitcast(dtypes.int64).item()
  17. return t, json_len, json.loads(t[8:8+json_len].numpy().tobytes())

def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
  """
  Loads a .safetensor file from disk, returning the state_dict.

  ```python
  state_dict = nn.state.safe_load("test.safetensor")
  ```
  """
  t, json_len, metadata = safe_load_metadata(fn)
  ret = {}
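  # each entry is a lazy bitcast slice of the disk-backed buffer; no bytes move until the tensor is realized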
  for k,v in metadata.items():
    if k == "__metadata__": continue
    dtype = safe_dtypes[v['dtype']]
    sz = (v['data_offsets'][1]-v['data_offsets'][0])
    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
  return ret

def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
  """
  Saves a state_dict to disk in a .safetensor file with optional metadata.

  ```python
  t = Tensor([1, 2, 3])
  nn.state.safe_save({'t':t}, "test.safetensor")
  ```
  """
  headers, offset = {}, 0
  if metadata: headers['__metadata__'] = metadata
  for k,v in tensors.items():
    headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
    offset += v.nbytes()
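  # safetensors layout: 8 bytes of header length, the JSON header (space-padded to a multiple of 8), then the raw tensor data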
  j = json.dumps(headers, separators=(',', ':'))
  j += "\x20"*((8-len(j)%8)%8)
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:8].bitcast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(list(j.encode('utf-8')))
  for k,v in safe_load(t).items(): v.assign(tensors[k])

# state dict

from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
  """
  Returns a state_dict of the object, with optional prefix.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(nn.state.get_state_dict(net).keys())
  ```
  """
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if hasattr(obj, '_asdict'): return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
  elif isinstance(obj, dict):
    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
  return state_dict

def get_parameters(obj) -> List[Tensor]:
  """
  Returns a list of all Tensors found in the object's state_dict.

  ```python exec="true" source="above" session="tensor" result="python"
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  print(len(nn.state.get_parameters(net)))
  ```
  """
  return list(get_state_dict(obj).values())

def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=True, consume=False) -> None:
  """
  Loads a state_dict into a model.
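
  With `strict=True`, a key missing from `state_dict` raises a KeyError; `verbose` toggles the progress bar;
  `consume=True` deletes entries from `state_dict` as they are loaded, dropping references to the source tensors.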

  ```python
  class Net:
    def __init__(self):
      self.l1 = nn.Linear(4, 5)
      self.l2 = nn.Linear(5, 6)

  net = Net()
  state_dict = nn.state.get_state_dict(net)
  nn.state.load_state_dict(net, state_dict)
  ```
  """
  start_mem_used = GlobalCounters.mem_used
  with Timing("loaded weights in ", lambda et_ns: f", {(GlobalCounters.mem_used-start_mem_used)/1e9:.2f} GB loaded at {(GlobalCounters.mem_used-start_mem_used)/et_ns:.2f} GB/s"):  # noqa: E501
    model_state_dict = get_state_dict(model)
    if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
      print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys())))
    for k,v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
      t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: "
      if k not in state_dict and not strict:
        if DEBUG >= 1: print(f"WARNING: not loading {k}")
        continue
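      # a multi-device model tensor: shard the incoming weight the same way unless it is already sharded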
      if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
        if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
        else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
      else: v.replace(state_dict[k].to(v.device)).realize()
      if consume: del state_dict[k]

# torch support!

def torch_load(fn:str) -> Dict[str, Tensor]:
  """
  Loads a torch .pth file from disk.

  ```python
  state_dict = nn.state.torch_load("test.pth")
  ```
  """
  t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
  offsets: Dict[Union[str, int], int] = {}
  lens: Dict[Union[str, int], int] = {}
  def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
    #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
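    # the legacy (non-zip) format parses the pickle twice: the first pass lands here before offsets are known,
    # records the storage length, and returns None; the second pass builds the real tensors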
    lens[storage[2]] = storage[4] * storage[1].itemsize
    if storage[2] not in offsets: return None
    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])

    # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
    shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
    permute_indexes = [len(shape_strides)-1-y for y in argsort([x[1] for x in shape_strides])]
    if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
      intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
      assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
      if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
      assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
      # TODO: find a nice way to support all shapetracker on disktensors
      ret = ret.clang().reshape(intermediate_shape).permute(permute_indexes)

    return ret.reshape(size)

  class Parameter:
    def __setstate__(self, state): self.tensor = state[0]

  deserialized_objects: Dict[str, Any] = {}
  intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32,
               "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
  whitelist = {"torch", "collections", "numpy", "_codecs"}  # NOTE: this is not for security, only speed
  class Dummy: pass
  class TorchPickle(pickle.Unpickler):
    def find_class(self, module, name):
      module_root = module.split(".")[0]
      if module_root not in whitelist:
        if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
        return Dummy
      return intercept[name] if module_root == "torch" else super().find_class(module, name)
    def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
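
  # dispatch on the checkpoint container: zip-based .pth (torch >= 1.6), the old tar-based format, or a bare legacy pickle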
  if zipfile.is_zipfile(fn):
    myzip = zipfile.ZipFile(fn, 'r')
    base_name = myzip.namelist()[0].split('/', 1)[0]
    for n in myzip.namelist():
      if n.startswith(f'{base_name}/data/'):
        with myzip.open(n) as myfile:
          offsets[n.split("/")[-1]] = myfile._orig_compress_start  # type: ignore
    with myzip.open(f'{base_name}/data.pkl') as myfile:
      return TorchPickle(myfile).load()
  elif tarfile.is_tarfile(fn):
    with tarfile.open(fn, "r") as tar:
      storages_offset = tar.getmember('storages').offset_data
      f = unwrap(tar.extractfile('storages'))
      for i in range(TorchPickle(f).load()):  # num_storages
        (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
        offsets[key] = storages_offset + f.tell()
        f.seek(sz*storage_type.itemsize, 1)
      f = unwrap(tar.extractfile('tensors'))
      for _ in range(TorchPickle(f).load()):  # num_tensors
        (key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
        size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
        storage_offset = struct.unpack('<q', f.read(8))[0]
        deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
      return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
  else:
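    # bare legacy pickle: magic, protocol, and sys_info pickles, then the state_dict pickle, the storage keys,
    # and each storage as an 8-byte element count followed by raw data; seek back and re-load once offsets are known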
    with open(fn, "rb") as f:
      pkl = TorchPickle(f)
      _, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
      for i in ids:
        offsets[i] = base_offset + 8
        base_offset += 8 + lens[i]
      f.seek(rwd)
      return TorchPickle(f).load()