graph.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os, atexit, functools, contextlib
  2. from collections import defaultdict
  3. from typing import List, Any, DefaultDict, Union
  4. from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, BufferOps, TernaryOps, LazyOp
  5. from tinygrad.device import Device
  6. from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
  7. from tinygrad.codegen.uops import UOps, UOp
  8. from tinygrad.codegen.uopgraph import UPat
  9. from tinygrad.shape.symbolic import NumNode
  10. from tinygrad.lazy import LazyBuffer
  11. with contextlib.suppress(ImportError): import networkx as nx
  12. # **** debugging and graphing ****
  13. if DEBUG >= 2:
  14. def print_globalcounters():
  15. if GlobalCounters.time_sum_s == 0: return
  16. print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
  17. f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
  18. atexit.register(print_globalcounters)
  19. def save_graph(G, fn, opt=""):
  20. print("saving", G, f"to {fn}.svg")
  21. nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
  22. os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
  23. G:Any = None
  24. def init_graph():
  25. global G
  26. if G is not None: return
  27. G = nx.DiGraph()
  28. atexit.register(functools.partial(save_graph, G, GRAPHPATH)) # -Gnslimit=100 can make it finish, but you won't like results
  29. counts: DefaultDict[type, int] = defaultdict(int)
  30. def nm(x):
  31. if not hasattr(x, 'node_id'):
  32. setattr(x, 'node_id', counts[type(x)])
  33. counts[type(x)] += 1
  34. return x.node_id
  35. def realized_lazybuffer(lb:'LazyBuffer', num):
  36. init_graph()
  37. G.nodes[nm(lb)]['style'] = '"filled,bold"'
  38. G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
  39. G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num}"'
  40. top_colors = {MetaOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
  41. TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
  """Add lb to the module-level graph G, first logging any of its sources that are not in G yet.

  - Unrealized CONST bases are not drawn. When one is a source, it is folded into the consumer's label as "CONST{idx} {arg}".
  - If lb is not its own base (a view), a dashed light-green node labelled with its ShapeTracker shape and strides is added,
    plus an edge from the base. The base is then logged in lb's place.
  - An unrealized base gets a dashed node colored by its op's enum class (top_colors). The label holds shape, dtype (if not
    float), op, arg (CONST/CAST only), device (if not Device.DEFAULT), folded CONST sources and metadata.
    scheduled=True gives that node a box shape.
  - A realized base that is not yet in G gets a node labelled from its realized buffer.
  """
  init_graph()
  # unrealized CONST buffers get no node of their own (consumers fold them into their label below)
  if lb.base.realized is None and lb.base.op is MetaOps.CONST: return
  if lb.base != lb:
    # view node: labelled with the ShapeTracker's shape and strides, plus the index expression evaluated at all-zero
    # indices when it is nonzero (presumably the view's offset into its base -- TODO confirm)
    offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
    label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
    G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
    G.add_edge(nm(lb.base), nm(lb), color='#00000060')
    # continue with the base buffer
    lb = lb.base
  if lb.realized is None:
    # extra label lines describing CONST sources that are folded in instead of drawn as edges
    label_append = []
    for idx,x in enumerate(lb.srcs):
      # NOTE: nm() assigns an id on its first call, so the order of nm() calls here determines node numbering
      if nm(x) not in G.nodes: log_lazybuffer(x)
      if x.base.realized is None and x.base.op is MetaOps.CONST:
        label_append.append(f"\nCONST{idx} {x.base.arg:g}")
      else:
        G.add_edge(nm(x), nm(lb), color='#a0a0a0')
    # the label is wrapped in literal double quotes (presumably so pydot/dot treat it as a single string). It contains:
    # shape (reduces also list the set of source shapes), dtype unless float, op, arg for CONST/CAST,
    # device unless Device.DEFAULT, the folded CONST sources, then metadata
    label = '"' + \
      (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
      (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {MetaOps.CONST, UnaryOps.CAST} else "") + \
      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
    # fill color comes from the first top_colors entry whose enum class contains lb.op; the appended "80" is
    # presumably an alpha byte (realized_lazybuffer strips the last two characters again)
    G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
    if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
  else:
    if nm(lb) not in G.nodes:
      # realized but unseen?
      # the label is str(realized) with its first 5 and last character cut off (TODO confirm what those are) and spaces
      # turned into newlines via chr(10), followed by "b:<node id of the realized buffer>"
      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
  69. def _tree(dag:Union[LazyOp, UOp, UPat], cycles, cnt):
  70. cnt[0] += 1
  71. src = dag.src if isinstance(dag.src, (list, tuple)) else [] if dag.src is None else [dag.src]
  72. if len(src) == 0: return [f"━━ {dag.op} {dag.arg}"]
  73. if (lid := id(dag)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
  74. return [f"━⬆︎ goto {cycles[id(dag)][0]}: {dag.op}"]
  75. cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
  76. lines = [f"━┳ {dag.op} {dag.arg}"]
  77. childs = [_tree(c, cycles, cnt) for c in src]
  78. for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
  79. return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
  80. def print_tree(dag:Union[LazyOp, UOp, UPat]): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(dag, {}, [-1]))]))
  81. def graph_uops(uops:List[UOp]):
  82. colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0",
  83. UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
  84. UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
  85. G = nx.DiGraph()
  86. for u in uops:
  87. if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
  88. G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg).replace(':', '')) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
  89. for v in u.src: G.add_edge(uops.index(v), uops.index(u))
  90. save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')