# handcode_opt.py

from typing import List
from extra.models.resnet import ResNet50
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Compiled
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv
from tinygrad.ops import MetaOps, get_lazyop_info
from tinygrad.shape.symbolic import sym_infer
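
# Environment knobs read below (all optional): MODEL picks the schedule builder ("resnet" or "bert"),
# BS sets the batch size, BACKWARD also schedules the optimizer step, HALF (default on) switches the
# default float to half, KERNEL focuses on a single kernel index, BEAM runs a beam search of that
# width (BEAM_ESTIMATE is forwarded to it), and DEBUG controls how much is printed.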

def get_sched_resnet():
  mdl = ResNet50()
  optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
  BS = getenv("BS", 64)

  # run model twice to get only what changes, these are the kernels of the model
  seen = set()
  for _ in range(2):
    out = mdl(Tensor.empty(BS, 3, 224, 224))
    targets = [out.lazydata]
    if getenv("BACKWARD"):
      optim.zero_grad()
      out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
      targets += [x.lazydata for x in optim.schedule_step()]
    sched = create_schedule(targets, seen)
    print(f"schedule length {len(sched)}")
  return sched

def get_sched_bert():
  mdl = get_mlperf_bert_model()
  optim = nn.optim.LAMB(nn.state.get_parameters(mdl))

  # fake data
  BS = getenv("BS", 2)
  input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
  segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
  attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
  masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
  masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
  masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
  next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
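  # only shapes and dtypes matter here: Tensor.empty allocates uninitialized tensors,
  # which is enough to build the schedule without touching a real dataset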

  # run model twice to get only what changes, these are the kernels of the model
  seen = set()
  for _ in range(2):
    lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
    targets = [lm_logits.lazydata, seq_relationship_logits.lazydata]
    if getenv("BACKWARD"):
      optim.zero_grad()
      loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
      # ignore grad norm and loss scaler for now
      loss.backward()
      targets += [x.lazydata for x in optim.schedule_step()]
    sched = create_schedule(targets, seen)
    print(f"schedule length {len(sched)}")
  return sched
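
# For every kernel in the chosen schedule, the main block builds a few candidate Kernels
# (hand-coded optimizations, tensor cores when they apply, and optionally a BEAM search),
# times each candidate with time_linearizer, and keeps the fastest one for the final totals.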

if __name__ == "__main__":
  if getenv("HALF", 1):
    dtypes.default_float = dtypes.half

  # the device we are optimizing for
  device: Compiled = Device[Device.DEFAULT]
  if getenv("BACKWARD"): Tensor.training = True
  print(f"optimizing for {Device.DEFAULT}")

  sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
  sched = [x for x in sched if x.ast.op is MetaOps.KERNEL]

  # focus on one kernel
  if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]

  # work with the schedule
  total_tm = 0
  running_gflops = 0
  usage = {}
  for i,si in enumerate(sched):
    ops = get_lazyop_info(si.ast.src[0]).flops

    if DEBUG >= 2:
      print_tree(si.ast)

    rawbufs = bufs_from_lin(Kernel(si.ast))
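    # rawbufs are reused for every candidate below, so each one is timed against the same buffers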
  75. # "linearize" the op into uops in different ways
  76. lins:List[Kernel] = []
  77. # always try hand coded opt
  78. lin = Kernel(si.ast, opts=device.renderer)
  79. lin.hand_coded_optimizations()
  80. lins.append(lin)
  81. # maybe try tensor cores
  82. lin = Kernel(si.ast, opts=device.renderer)
  83. if lin.apply_tensor_cores():
  84. lins.append(lin)
  85. # try a beam search
  86. if beam:=getenv("BEAM"):
  87. lin = Kernel(si.ast, opts=device.renderer)
  88. lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
  89. lins.append(lin)
  90. # benchmark the programs
  91. choices = []
  92. for lin in lins:
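      # wall-clock the candidate, then turn the (possibly symbolic) flop count into GFLOPS
      # by substituting the minimum value of any symbolic variables in the AST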
      tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
      gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
      choices.append((tm, gflops, lin.linearize()))

      # print all kernels
      if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")

    tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
    total_tm += tm
    running_gflops += gflops * tm
    if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
    usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
    print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[str(m) for m in si.metadata] if si.metadata is not None else ''}")

  print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
  print("usage:")
  for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
    print(f"{usage[k][0]*1000:.2f} ms: {k} ({usage[k][1]} times)")