helpers.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. from __future__ import annotations
  2. import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes, sys
  3. import itertools, urllib.request, subprocess, shutil, math, json, contextvars
  4. from dataclasses import dataclass
  5. from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
  6. if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
  7. from typing_extensions import TypeGuard
  8. from tinygrad.shape.shapetracker import sint
# generic type parameters used by the helper signatures below
T = TypeVar("T")
U = TypeVar("U")
  11. # NOTE: it returns int 1 if x is empty regardless of the type of x
  12. def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = platform.system() == "Darwin"  # True when running on macOS
CI = os.getenv("CI", "") != ""       # True inside a CI environment (any non-empty $CI)
  16. def dedup(x:Iterable[T]): return list(dict.fromkeys(x)) # retains list order
  17. def argfix(*x):
  18. if x and x[0].__class__ in (tuple, list):
  19. if len(x) != 1: raise ValueError(f"bad arg {x}")
  20. return tuple(x[0])
  21. return x
  22. def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
  23. def all_same(items:List[T]): return all(x == items[0] for x in items)
  24. def all_int(t: Sequence[Any]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
  25. def colored(st, color:Optional[str], background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line # noqa: E501
  26. def colorize_float(x: float): return colored(f"{x:7.2f}x", 'green' if x < 0.75 else 'red' if x > 1.15 else 'yellow')
  27. def ansistrip(s:str): return re.sub('\x1b\\[(K|.*?m)', '', s)
  28. def ansilen(s:str): return len(ansistrip(s))
  29. def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
  30. def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
  31. def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
  32. def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
  33. def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
  34. def round_up(num, amt:int): return (num+amt-1)//amt * amt
  35. def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
  36. assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
  37. return {k:v for d in ds for k,v in d.items()}
  38. def partition(lst:List[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
  39. a:List[T] = []
  40. b:List[T] = []
  41. for s in lst: (a if fxn(s) else b).append(s)
  42. return a,b
  43. def unwrap(x:Optional[T]) -> T:
  44. assert x is not None
  45. return x
  46. def unwrap2(x:Tuple[T,Any]) -> T:
  47. ret, err = x
  48. assert err is None, str(err)
  49. return ret
  50. def get_child(obj, key):
  51. for k in key.split('.'):
  52. if k.isnumeric(): obj = obj[int(k)]
  53. elif isinstance(obj, dict): obj = obj[k]
  54. else: obj = getattr(obj, k)
  55. return obj
  56. def get_shape(x) -> Tuple[int, ...]:
  57. if not isinstance(x, (list, tuple)): return ()
  58. subs = [get_shape(xi) for xi in x]
  59. if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
  60. return (len(subs),) + (subs[0] if subs else ())
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
  # cumulative products: equal running sizes mark compatible axis boundaries between the two shapes
  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
  # for each new-shape boundary, find the matching old-shape boundary (size-1 prefixes map to boundary 0)
  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
  except ValueError: return None  # a new-axis boundary has no counterpart -> new_shape is not a pure merge of old axes
  # consecutive boundary pairs become the groups of old axes merged into each new axis
  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
  67. @functools.lru_cache(maxsize=None)
  68. def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
  69. @functools.lru_cache(maxsize=None)
  70. def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
  71. def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
  72. class GraphException(Exception): pass
class Context(contextlib.ContextDecorator):
  """Temporarily override ContextVar values, e.g. `with Context(DEBUG=2): ...`; prior values are restored on exit."""
  stack: ClassVar[List[dict[str, int]]] = [{}]  # stack[-1] holds the snapshot to restore from; entries above it hold the overrides
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
    Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
    Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
  def __exit__(self, *args):
    # restore only the keys this Context overrode, from the snapshot directly below on the stack
    for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
class ContextVar:
  """Process-wide int flag, initialized once from the environment; instances are memoized per key."""
  _cache: ClassVar[Dict[str, ContextVar]] = {}
  value: int  # current value (mutated by Context overrides)
  key: str    # environment-variable name this flag was read from
  def __new__(cls, key, default_value):
    # singleton per key: a second ContextVar("DEBUG", ...) returns the existing instance unchanged
    if key in ContextVar._cache: return ContextVar._cache[key]
    instance = ContextVar._cache[key] = super().__new__(cls)
    instance.value, instance.key = getenv(key, default_value), key
    return instance
  # comparisons delegate to the int value so callers can write `if DEBUG >= 2:` directly
  def __bool__(self): return bool(self.value)
  def __ge__(self, x): return self.value >= x
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x
# global feature flags, each overridable via the same-named environment variable
DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
# NOTE: GRAPHPATH is a plain string read once via getenv, not a ContextVar
GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
MULTIOUTPUT, PROFILE, TRANSCENDENTAL = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("TRANSCENDENTAL", 1)
USE_TC, TC_OPT = ContextVar("TC", 1), ContextVar("TC_OPT", 0)
FUSE_AS_ONE_KERNEL = ContextVar("FUSE_AS_ONE_KERNEL", 0)
@dataclass(frozen=True)
class Metadata:
  """Provenance tag for an operation: its name plus where it was created from."""
  name: str
  caller: str
  backward: bool = False
  # hash is by name only — caller/backward do not participate
  def __hash__(self): return hash(self.name)
  def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
  def __str__(self): return self.name + (" bw" if self.backward else "")
# ambient Metadata for currently-created ops; contextvars keeps it per execution context
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
# **************** global state Counters ****************
class GlobalCounters:
  """Process-wide accumulators; reset() zeroes everything except mem_used."""
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0 # NOTE: this is not reset
  @staticmethod
  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
  119. # **************** timer and profiler ****************
  120. class Timing(contextlib.ContextDecorator):
  121. def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
  122. def __enter__(self): self.st = time.perf_counter_ns()
  123. def __exit__(self, *exc):
  124. self.et = time.perf_counter_ns() - self.st
  125. if self.enabled: print(f"{self.prefix}{self.et*1e-6:6.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
  126. def _format_fcn(fcn): return f"{fcn[0]}:{fcn[1]}:{fcn[2]}"
class Profiling(contextlib.ContextDecorator):
  """cProfile wrapper: profile the with-block, then print the top `frac` of functions sorted by `sort`."""
  def __init__(self, enabled=True, sort='cumtime', frac=0.2, fn=None, ts=1):
    # time_scale converts profiler seconds into ms divided by ts
    self.enabled, self.sort, self.frac, self.fn, self.time_scale = enabled, sort, frac, fn, 1e3/ts
  def __enter__(self):
    self.pr = cProfile.Profile()
    if self.enabled: self.pr.enable()
  def __exit__(self, *exc):
    if self.enabled:
      self.pr.disable()
      if self.fn: self.pr.dump_stats(self.fn)  # optionally persist raw stats for external viewers
      stats = pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort)
      for fcn in stats.fcn_list[0:int(len(stats.fcn_list)*self.frac)]: # type: ignore[attr-defined]
        (_primitive_calls, num_calls, tottime, cumtime, callers) = stats.stats[fcn] # type: ignore[attr-defined]
        scallers = sorted(callers.items(), key=lambda x: -x[1][2])  # heaviest caller (by time) first
        print(f"n:{num_calls:8d} tm:{tottime*self.time_scale:7.2f}ms tot:{cumtime*self.time_scale:7.2f}ms",
              colored(_format_fcn(fcn), "yellow") + " "*(50-len(_format_fcn(fcn))),
              colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if len(scallers) else '')
class ProfileLogger:
  """Collects timeline events; when the last live instance is deleted, dumps a Chrome trace-event JSON to `path`."""
  writers: int = 0  # refcount of live ProfileLogger instances
  mjson: List[Dict] = []  # shared accumulated trace-event records
  actors: Dict[str, int] = {}  # actor name -> pid
  subactors: Dict[Tuple[str, str], int] = {}  # (actor, subactor) -> tid
  path = getenv("PROFILE_OUTPUT_FILE", temp("tinygrad_profile.json"))
  def __init__(self): self.events, ProfileLogger.writers = [], ProfileLogger.writers + 1
  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor)]
  def __del__(self):
    # flush this instance's events into the shared mjson, assigning pids/tids on first sight
    for name,st,et,actor_name,subactor_name in self.events:
      if actor_name not in self.actors:
        self.actors[actor_name] = (pid:=len(self.actors))
        self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
      if (subactor_key:=(actor_name,subactor_name)) not in self.subactors:
        self.subactors[subactor_key] = (tid:=len(self.subactors))
        self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
      self.mjson.append({"name": name, "ph": "X", "pid": self.actors[actor_name], "tid": self.subactors.get(subactor_key, -1), "ts":st, "dur":et-st})
    ProfileLogger.writers -= 1
    # last writer standing saves the combined trace to disk
    if ProfileLogger.writers == 0 and len(self.mjson) > 0:
      with open(self.path, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
      print(f"Saved profile to {self.path}. Use https://ui.perfetto.dev/ to open it.")
# *** universal database cache ***
# cache root: $XDG_CACHE_HOME, else ~/Library/Caches on macOS or ~/.cache elsewhere
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)  # 0 disables the disk cache entirely
VERSION = 16  # baked into table names, so bumping it effectively invalidates all cached entries
_db_connection = None  # lazily-created singleton sqlite3 connection (see db_connection)
def db_connection():
  """Lazily open (once per process) and return the global sqlite3 connection to CACHEDB."""
  global _db_connection
  if _db_connection is None:
    os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)  # ensure the parent directory exists
    _db_connection = sqlite3.connect(CACHEDB, timeout=60, isolation_level="IMMEDIATE")
    _db_connection.execute("PRAGMA journal_mode=WAL")  # WAL: readers don't block the writer
    if DEBUG >= 7: _db_connection.set_trace_callback(print)  # very verbose SQL tracing
  return _db_connection
  179. def diskcache_clear():
  180. cur = db_connection().cursor()
  181. drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
  182. cur.executescript("\n".join([s[0] for s in drop_tables]))
  183. def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
  184. if CACHELEVEL == 0: return None
  185. if isinstance(key, (str,int)): key = {"key": key}
  186. conn = db_connection()
  187. cur = conn.cursor()
  188. try:
  189. res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
  190. except sqlite3.OperationalError:
  191. return None # table doesn't exist
  192. if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
  193. return None
_db_tables = set()  # tables whose CREATE has already run in this process (skip repeated DDL)
def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
  """Store pickled `val` under `key` in `table` and return it; pass-through when CACHELEVEL=0."""
  if CACHELEVEL == 0: return val
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
  if table not in _db_tables:
    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
    # column types are derived from the Python types of the key values
    ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
    # table name embeds VERSION, so bumping it starts a fresh cache
    cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
    _db_tables.add(table)
  cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501
  conn.commit()
  cur.close()
  return val
  209. def diskcache(func):
  210. def wrapper(*args, **kwargs) -> bytes:
  211. table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
  212. if (ret:=diskcache_get(table, key)): return ret
  213. return diskcache_put(table, key, func(*args, **kwargs))
  214. return wrapper
# *** http support ***
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
          allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
  """Download url into the cache dir (or to `name`) and return the local path; existing files are reused."""
  if url.startswith(("/", ".")): return pathlib.Path(url)  # local paths pass through untouched
  # a path-like name wins; otherwise the file lands in the cache dir, named by md5 of the url
  if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
  else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
  if not fp.is_file() or not allow_caching:
    with urllib.request.urlopen(url, timeout=10) as r:
      assert r.status == 200
      total_length = int(r.headers.get('content-length', 0))
      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
      (path := fp.parent).mkdir(parents=True, exist_ok=True)
      # download to a temp file in the same dir, then rename: readers never observe a partial file
      with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
        while chunk := r.read(16384): progress_bar.update(f.write(chunk))
      f.close()
      progress_bar.update(close=True)
      # NOTE(review): on error the delete=False temp file is left behind in the cache dir
      if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
      pathlib.Path(f.name).rename(fp)
  return fp
  234. # *** Exec helpers
  235. def cpu_time_execution(cb, enable):
  236. if enable: st = time.perf_counter()
  237. cb()
  238. if enable: return time.perf_counter()-st
def cpu_objdump(lib):
  """Debug helper: write the compiled library bytes to a temp file and print its objdump disassembly."""
  with tempfile.NamedTemporaryFile(delete=True) as f:
    pathlib.Path(f.name).write_bytes(lib)
    print(subprocess.check_output(['objdump', '-d', f.name]).decode('utf-8'))
  243. # *** ctypes helpers
  244. # TODO: make this work with read only memoryviews (if possible)
  245. def from_mv(mv:memoryview, to_type=ctypes.c_char):
  246. return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
  247. def to_mv(ptr, sz) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
  248. def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
  249. def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
  250. @functools.lru_cache(maxsize=None)
  251. def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
  252. class CStruct(ctypes.Structure):
  253. _pack_, _fields_ = 1, fields
  254. return CStruct
  255. def init_c_var(ctypes_var, creat_cb): return (creat_cb(ctypes_var), ctypes_var)[1]
  256. def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
# *** tqdm
class tqdm:
  """Minimal stdlib-only progress bar, API-compatible with the subset of tqdm used in this file."""
  def __init__(self, iterable=None, desc:str='', disable:bool=False, unit:str='it', unit_scale=False, total:Optional[int]=None, rate:int=100):
    self.iter, self.desc, self.dis, self.unit, self.unit_scale, self.rate = iterable, f"{desc}: " if desc else "", disable, unit, unit_scale, rate
    # t: total units (falls back to len(iterable), 0 if unsized); i counts update() calls; n counts progress units
    self.st, self.i, self.n, self.skip, self.t = time.perf_counter(), -1, 0, 1, getattr(iterable, "__len__", lambda:0)() if total is None else total
    self.update(0)
  def __iter__(self):
    for item in self.iter:
      yield item
      self.update(1)
    self.update(close=True)
  def set_description(self, desc:str): self.desc = f"{desc}: " if desc else ""
  def update(self, n:int=0, close:bool=False):
    self.n, self.i = self.n+n, self.i+1
    # throttle redraws: only every skip-th call renders (skip grows to cap redraws near self.rate per second)
    if self.dis or (not close and self.i % self.skip != 0): return
    prog, dur, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
    if self.i/dur > self.rate and self.i: self.skip = max(int(self.i/dur)//self.rate,1)
    # fmt: H:MM:SS with empty leading fields dropped; fn: ~4-char SI-prefixed number (k/M/G/...)
    def fmt(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
    def fn(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
    unit_text = f'{fn(self.n)}{f"/{fn(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
    it_text = (fn(self.n/dur) if self.unit_scale else f"{self.n/dur:5.2f}") if self.n else "?"
    tm = f'{fmt(dur)}<{fmt(dur/prog-dur) if self.n else "?"}' if self.t else fmt(dur)  # elapsed<remaining
    suf = f'{unit_text} [{tm}, {it_text}{self.unit}/s]'
    sz = max(ncols-len(self.desc)-5-2-len(suf), 1)  # columns left for the bar itself
    # unicode eighth-blocks give sub-character bar resolution
    bar = '\r' + self.desc + (f'{100*prog:3.0f}%|{("█"*int(num:=sz*prog)+" ▏▎▍▌▋▊▉"[int(8*num)%8].strip()).ljust(sz," ")}| ' if self.t else '') + suf
    print(bar[:ncols+1],flush=True,end='\n'*close,file=sys.stderr)
  283. class trange(tqdm):
  284. def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)