export.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python3
  2. import os
  3. if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
  4. from tinygrad import Device, nn, Tensor, dtypes, Variable
  5. Device.DEFAULT = "CLANG"
  6. from train_gpt2 import GPT, GPTConfig
  7. from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name
  8. from tinygrad.engine.schedule import create_schedule, memory_planner
  9. from tinygrad.engine.realize import get_kernel, run_schedule
  10. from tinygrad.ops import BufferOps, MetaOps
  11. TIMING = getenv("TIMING")
if __name__ == "__main__":
  # Build the model; weights are replaced with empty tensors below, so only the
  # compute graph matters here, not actual parameter values.
  model = GPT(GPTConfig(n_layer=getenv("NLAYER", 12), n_head=12, n_embd=768))
  #model.load_pretrained()
  for p in nn.state.get_parameters(model): p.replace(Tensor.empty(p.shape, dtype=p.dtype)) # fake load pretrained
  seen = set()
  #early_sched = create_schedule([x.lazydata for x in nn.state.get_parameters(model)], seen)
  #print(f"built model {len(early_sched)}")
  #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
  B, T = 4, 64  # fixed batch size and sequence length (symbolic shapes commented out above)
  Tensor.training = True
  optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
  # Run a few warmup iterations so the schedule stabilizes; `seen` is threaded
  # through create_schedule so already-scheduled buffers are skipped, and only
  # the final iteration's schedule is kept for export.
  warmup_count = getenv("WARMUP", 3)
  for i in range(warmup_count): # TODO: why does it take three and not two to stablize
    if i == warmup_count-1: GRAPH.value = getenv("LATEGRAPH")  # optionally graph only the last pass
    GlobalCounters.reset()
    # Placeholder int token inputs/targets; reshape is a no-op at (4, 64) but
    # matches the (B, T) shape the commented-out symbolic path would use.
    X = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
    Y = Tensor.empty(4, 64, dtype=dtypes.int).reshape(B, T)
    _, loss = model(X, Y)
    optimizer.zero_grad()
    if getenv("BACKWARD", 1):
      loss.backward()
      tensors = optimizer.schedule_step()  # tensors the optimizer step must realize
    else:
      tensors = []  # forward-only export
    sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors], seen)
    print(f"calls {i}:", len(sched))
    #run_schedule(sched[:])
  del seen # free the LazyBuffers
  sched = memory_planner(sched)  # let tinygrad reuse/assign buffers before we name them
  # Deduplicate kernel ASTs (non-STORE items are copy/meta ops with no C source)
  # and render each unique kernel to C once.
  ast_dedup = dedup([si.ast for si in sched if si.ast[0].op is BufferOps.STORE])
  srcs = {}
  for ast in ast_dedup:
    k = get_kernel(Device["CLANG"].renderer, ast)
    k.linearize()
    src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
    srcs[ast] = (k.name, src)  # keep the display name and the rendered C source
  print("functions:", len(srcs))
  # Every buffer referenced by the schedule gets a stable index; named state
  # buffers (below) override these b<N> fallback names.
  used_buffers = dedup(flatten([si.bufs for si in sched]))
  numbered_bufs = {x:i for i,x in enumerate(used_buffers)}
  print("buffers:", len(numbered_bufs))
  # Collect human-readable names for buffers: model weights, inputs, loss,
  # gradients, and Adam state, so the generated C is legible.
  state_dict = nn.state.get_state_dict(model)
  state_dict.update({'X': X, 'Y': Y, 'loss': loss})
  grad_state_dict = {}
  for k,v in state_dict.items():
    if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
    if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
  state_dict.update(grad_state_dict)
  state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
  # NOTE(review): inverting {name: tensor} assumes tensors are unique values —
  # a shared tensor would keep only its last name; presumably fine here.
  inverse_state_dict = {v:k for k,v in state_dict.items()}
  for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v):
    nm = inverse_state_dict[p]
    state_dict["adam_m_"+nm] = m
    state_dict["adam_v_"+nm] = v
  # Dots are invalid in C identifiers, so "blocks.0.attn" -> "blocks_0_attn".
  named_buffers = {v.lazydata.base.buffer:k.replace(".", "_") for k,v in state_dict.items()}
  # Assemble the C file: headers, kernel functions, then main() with one call
  # per scheduled STORE kernel, after malloc'ing every buffer.
  c_code = ["#include <stdlib.h>", "#include <tgmath.h>", "#include <stdbool.h>"]
  if TIMING: c_code += ["#include <stdio.h>", "#include <time.h>"]
  # strip "restrict": named buffers may alias across calls in the flat C output
  c_code += [x[1].replace(" restrict ", " ")+"\n" for x in srcs.values()]
  premain = ["int main() {"]
  if TIMING:
    premain += [" struct timespec tm0; clock_gettime(CLOCK_MONOTONIC, &tm0);"]
  lst = 0  # index of the previous timestamp, for per-kernel deltas (TIMING only)
  main = []
  all_bufs = []
  for i,si in enumerate(sched):
    # (C identifier, Buffer) pairs for this schedule item's arguments
    bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
    all_bufs += bufs
    if si.ast[0].op is not BufferOps.STORE:
      # copy/meta op: logged for inspection but not emitted into the C main()
      print(f"// {si.ast[0].op}", bufs)
    else:
      print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})")
      main.append(f" {to_function_name(srcs[si.ast][0])}({', '.join([x[0] for x in bufs])});")
      if TIMING:
        # timestamp after each kernel; print cumulative (vs tm0) and delta (vs tm{lst})
        main.append(f" struct timespec tm{i+1}; clock_gettime(CLOCK_MONOTONIC, &tm{i+1});")
        main.append(f" printf(\"%10.2f ms + %7.2f ms @ {to_function_name(srcs[si.ast][0])}\\n\"," +\
          f"((tm{i+1}.tv_sec-tm{0}.tv_sec) + (tm{i+1}.tv_nsec-tm{0}.tv_nsec) / 1e9) * 1e3," +\
          f"((tm{i+1}.tv_sec-tm{lst}.tv_sec) + (tm{i+1}.tv_nsec-tm{lst}.tv_nsec) / 1e9) * 1e3);")
        lst = i+1
      #call = f"{srcs[si.ast][0]}({', '.join(bufs)})"
      #call += " "*(80-ansilen(call))
      #print(f"{call} // {i+1}")
      #print(srcs[si.ast][1])
  main.append("}")
  # one malloc per unique buffer, typed by the buffer's dtype
  mallocs = [f" {b.dtype.name}* {n} = ({b.dtype.name}*)malloc({b.nbytes});" for n,b in dedup(all_bufs)]
  with open("out.c", "w") as f: f.write('\n'.join(c_code+premain+mallocs+main))