# external_benchmark_resnet.py
import functools
import time
import unittest
from tinygrad import Tensor, TinyJit, GlobalCounters, Device
from tinygrad.helpers import getenv, Context
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import run_schedule
from extra.models import resnet
from examples.mlperf.initializers import Conv2dHeNormal, Linear
from examples.hlb_cifar10 import UnsyncedBatchNorm

# benchmark memory or kernel count: DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py
# benchmark speed: BEAM=2 JITCNT=10 DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py
# benchmark only one layer: BEAM=2 DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py BenchmarkResnetTrain.test_layer1_2
# inspect: DEBUG=2 BEAM=2 DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py
# inspect 1x1 convs: DEBUG=2 BEAM=2 CONV=1 DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py
# inspect 3x3 convs: DEBUG=2 BEAM=2 CONV=2 DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py
# inspect 3x3 convs with batchnorm: DEBUG=2 BEAM=2 CONV=2 BN=1 DEFAULT_FLOAT=HALF python test/external/external_benchmark_resnet.py
# etc
# use ASSIGN=0 to disable batchnorm/optimizer assigns
# memory will be slightly high with JITCNT > 1

bs = getenv("BS", 64)  # benchmark batch size, overridable via BS env var
class BenchmarkResnetTrain(unittest.TestCase):
  """Benchmarks isolated ResNet50 residual-block slices (forward + backward + optimizer step) under TinyJit.

  Each test_layerX_Y method times one slice of one layer and feeds the result into a
  class-level accumulator; tearDownClass prints an estimate for a full training step.
  """
  def _get_layer(self, layer_i, slice_i):
    """Return (display name, list of layers to run, input channels, spatial size) for one slice.

    layer_i: 0-3 index into ResNet50's layer1..layer4; slice_i: block index within that layer.
    With CONV=n, isolates conv n of the bottleneck (optionally followed by its BN when BN=1).
    """
    # isolate to conv, with or without BN
    conv = getenv("CONV", 0)
    bn = getenv("BN", 0)
    if not hasattr(self, 'model'):
      # build the model once, patched to use mlperf initializers and (unless SYNCBN) unsynced batchnorm
      resnet.Conv2d = Conv2dHeNormal
      resnet.Linear = Linear
      if not getenv("SYNCBN"): resnet.BatchNorm = functools.partial(UnsyncedBatchNorm, num_devices=1)
      self.model = resnet.ResNet50()
      self.layers = [self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]

    layer = self.layers[layer_i][slice_i]
    # input spatial size: 112 halved per layer, and once more for strided slices
    xy = 112 >> layer_i
    xy >>= (1 if slice_i > 0 or layer_i == 0 else 0) # layer 1 is preceded by maxpool2d
    name = f"layer{layer_i+1} slice{slice_i+1}"

    # get specific conv
    if conv:
      convs = [layer.conv1, layer.conv2, layer.conv3] + ([layer.downsample[0]] if layer.downsample else [])
      bns = [layer.bn1, layer.bn2, layer.bn3] + ([layer.downsample[1]] if layer.downsample else [])
      f = [convs[conv-1]]
      if bn: f.append(bns[conv-1])
      f.append(Tensor.relu)
      cin = f[0].in_channels
      if conv == 3: xy //= convs[1].stride  # conv3 sees the spatial size after conv2's stride
      return f"{name} conv{conv} x{str((bs, cin, xy, xy)):20s} k{str(f[0].weight.shape):20s}" + (" bn" if bn else ""), f, cin, xy
    cin = layer.conv1.in_channels
    return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy

  def _test_layer(self, name, layer, cin, xy):
    """Time CNT jitted fwd+bwd(+optim) steps; print and return (best time, flops, mem, kernels).

    All reported figures are normalized by JITCNT (the step is replayed JITCNT times per call).
    """
    optim = SGD(get_parameters(layer), bs / 128 * 1.0) # need sgd for some params but not consequential for benchmarking
    # materialize params up-front so their init isn't captured in the timed schedule
    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t.detach().contiguous()) for t in get_parameters(optim)])

    JITCNT = getenv("JITCNT", 1)
    Tensor.training = True
    @TinyJit
    def step(x):
      optim.zero_grad()
      x.grad = None
      y = x.sequential(layer).contiguous().contiguous_backward()
      y.sum().backward()
      # ASSIGN=0 schedules gradients only, dropping the batchnorm/optimizer assigns
      if getenv("ASSIGN", 1): sched, _ = Tensor.schedule_with_vars(y, x.grad, *optim.schedule_step())
      else: sched, _ = Tensor.schedule_with_vars(y, x.grad, *[t.grad for t in optim.params])
      for _ in range(JITCNT):
        run_schedule(list(sched))

    CNT = getenv("CNT", 5)
    best_tm = None
    flops, mem_used, mem, kernels = None, None, None, None
    for i in range(CNT):
      # fresh input each iteration; realized outside the timed region
      with Context(SAVE_SCHEDULE=0): x = Tensor.randn(bs, cin, xy, xy, requires_grad=True).realize()
      GlobalCounters.reset()

      st = time.perf_counter()
      step(x)
      Device[Device.DEFAULT].synchronize()  # wait for the device so et covers the whole step
      et = time.perf_counter()

      flops = GlobalCounters.global_ops / JITCNT
      mem_used = GlobalCounters.mem_used # a little high with JITCNT > 1 fsr
      mem = GlobalCounters.global_mem / JITCNT
      if kernels is None: kernels = GlobalCounters.kernel_count // JITCNT
      tm = (et-st) / JITCNT
      if best_tm is None or tm < best_tm: best_tm = tm
    print(f"\r{name:38s}: {best_tm * 1000:>9.2f} ms, {flops / 10**12 / best_tm:>6.2f} tflops, {mem / 10**9 / best_tm:>5.0f} GB/s, "
          f"{mem_used / 10**9: 6.2f} GB used, {kernels:>5d} kernels")
    return best_tm, flops, mem, kernels

  # one test per (layer, slice); the trailing int is the multiplicity passed to _est —
  # slice 2 results stand in for the layer's remaining identical blocks
  def test_layer1_1(self): self._est(*self._test_layer(*self._get_layer(0, 0)), 1)
  def test_layer1_2(self): self._est(*self._test_layer(*self._get_layer(0, 1)), 2)
  def test_layer2_1(self): self._est(*self._test_layer(*self._get_layer(1, 0)), 1)
  def test_layer2_2(self): self._est(*self._test_layer(*self._get_layer(1, 1)), 3)
  def test_layer3_1(self): self._est(*self._test_layer(*self._get_layer(2, 0)), 1)
  def test_layer3_2(self): self._est(*self._test_layer(*self._get_layer(2, 1)), 5)
  def test_layer4_1(self): self._est(*self._test_layer(*self._get_layer(3, 0)), 1)
  def test_layer4_2(self): self._est(*self._test_layer(*self._get_layer(3, 1)), 2)

  # running totals across tests, reported in tearDownClass
  est_tm, est_flops, est_mem, est_kernels = 0, 0, 0, 0
  @classmethod
  def _est(cls, tm, flops, mem, kernels, mult):
    """Accumulate one layer's measurements, weighted by mult (number of equivalent blocks)."""
    cls.est_tm += tm * mult
    cls.est_flops += flops * mult
    cls.est_mem += mem * mult
    cls.est_kernels += kernels * mult

  @classmethod
  def tearDownClass(cls):
    # print the estimated full-step totals after all selected tests have run
    print(f"\restimated step tm: {cls.est_tm * 1000.0:.2f} ms, {cls.est_flops / 10 ** 12 / cls.est_tm:.3f} tflops, "
          f"{cls.est_mem / 10 ** 9 / cls.est_tm:.2f} GB/s, {cls.est_kernels} kernels")
if __name__ == '__main__':
  # run the benchmark methods through the standard unittest runner
  unittest.main()