# test_net_speed.py
  1. #!/usr/bin/env python
  2. import time
  3. import unittest
  4. import torch
  5. from tinygrad import Tensor, Device
  6. from tinygrad.helpers import Profiling, CI
  7. @unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
  8. class TestConvSpeed(unittest.TestCase):
  9. def test_mnist(self):
  10. # https://keras.io/examples/vision/mnist_convnet/
  11. conv = 3
  12. inter_chan, out_chan = 32, 64
  13. # ****** torch baseline *******
  14. torch.backends.mkldnn.enabled = False
  15. conv = 3
  16. inter_chan, out_chan = 32, 64
  17. c1 = torch.randn(inter_chan,1,conv,conv, requires_grad=True)
  18. c2 = torch.randn(out_chan,inter_chan,conv,conv, requires_grad=True)
  19. l1 = torch.randn(out_chan*5*5, 10, requires_grad=True)
  20. c2d = torch.nn.functional.conv2d
  21. mp = torch.nn.MaxPool2d((2,2))
  22. lsm = torch.nn.LogSoftmax(dim=1)
  23. cnt = 5
  24. fpt, bpt = 0.0, 0.0
  25. for i in range(cnt):
  26. et0 = time.time()
  27. x = torch.randn(128, 1, 28, 28, requires_grad=True)
  28. x = mp(c2d(x,c1).relu())
  29. x = mp(c2d(x,c2).relu())
  30. x = x.reshape(x.shape[0], -1)
  31. out = lsm(x.matmul(l1))
  32. out = out.mean()
  33. et1 = time.time()
  34. out.backward()
  35. et2 = time.time()
  36. fpt += (et1-et0)
  37. bpt += (et2-et1)
  38. fpt_baseline = (fpt*1000/cnt)
  39. bpt_baseline = (bpt*1000/cnt)
  40. print("torch forward pass: %.3f ms" % fpt_baseline)
  41. print("torch backward pass: %.3f ms" % bpt_baseline)
  42. # ****** tinygrad compare *******
  43. c1 = Tensor(c1.detach().numpy(), requires_grad=True)
  44. c2 = Tensor(c2.detach().numpy(), requires_grad=True)
  45. l1 = Tensor(l1.detach().numpy(), requires_grad=True)
  46. cnt = 5
  47. fpt, bpt = 0.0, 0.0
  48. for i in range(1+cnt):
  49. et0 = time.time()
  50. x = Tensor.randn(128, 1, 28, 28)
  51. x = x.conv2d(c1).relu().avg_pool2d()
  52. x = x.conv2d(c2).relu().max_pool2d()
  53. x = x.reshape(shape=(x.shape[0], -1))
  54. out = x.dot(l1).log_softmax()
  55. out = out.mean()
  56. out.realize()
  57. et1 = time.time()
  58. out.backward()
  59. [x.grad.realize() for x in [c1, c2, l1]]
  60. et2 = time.time()
  61. if i == 0:
  62. pr = Profiling(sort='time', frac=0.2)
  63. pr.__enter__()
  64. else:
  65. fpt += (et1-et0)
  66. bpt += (et2-et1)
  67. pr.__exit__()
  68. fpt = (fpt*1000/cnt)
  69. bpt = (bpt*1000/cnt)
  70. print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
  71. print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))
# Allow running this benchmark directly: `python test_net_speed.py`
if __name__ == '__main__':
  unittest.main()