| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- #!/usr/bin/env python
- import time
- import unittest
- import torch
- from tinygrad import Tensor, Device
- from tinygrad.helpers import Profiling, CI
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestConvSpeed(unittest.TestCase):
    """Speed comparison of a small MNIST convnet: tinygrad versus a torch baseline."""

    def test_mnist(self):
        """Time forward/backward passes of an MNIST convnet in torch and tinygrad.

        The torch timings, averaged over ``cnt`` iterations, are the baseline.
        The tinygrad network is built from the same weights, runs one untimed
        warm-up iteration, and is then timed for ``cnt`` iterations under
        ``Profiling``. Results are printed; nothing is asserted.
        """
        # https://keras.io/examples/vision/mnist_convnet/
        conv = 3
        inter_chan, out_chan = 32, 64
        cnt = 5

        # ****** torch baseline *******
        mkldnn_was_enabled = torch.backends.mkldnn.enabled
        torch.backends.mkldnn.enabled = False
        try:
            tc1 = torch.randn(inter_chan, 1, conv, conv, requires_grad=True)
            tc2 = torch.randn(out_chan, inter_chan, conv, conv, requires_grad=True)
            tl1 = torch.randn(out_chan * 5 * 5, 10, requires_grad=True)

            c2d = torch.nn.functional.conv2d
            mp = torch.nn.MaxPool2d((2, 2))
            lsm = torch.nn.LogSoftmax(dim=1)

            fpt, bpt = 0.0, 0.0
            for _ in range(cnt):
                # perf_counter is monotonic and high resolution, unlike time.time().
                et0 = time.perf_counter()
                x = torch.randn(128, 1, 28, 28, requires_grad=True)
                x = mp(c2d(x, tc1).relu())
                x = mp(c2d(x, tc2).relu())
                x = x.reshape(x.shape[0], -1)
                out = lsm(x.matmul(tl1))
                out = out.mean()
                et1 = time.perf_counter()
                out.backward()
                et2 = time.perf_counter()
                fpt += et1 - et0
                bpt += et2 - et1
        finally:
            # Do not leak the mkldnn override into other tests in this process.
            torch.backends.mkldnn.enabled = mkldnn_was_enabled

        fpt_baseline = fpt * 1000 / cnt
        bpt_baseline = bpt * 1000 / cnt
        print(f"torch forward pass: {fpt_baseline:.3f} ms")
        print(f"torch backward pass: {bpt_baseline:.3f} ms")

        # ****** tinygrad compare *******
        c1 = Tensor(tc1.detach().numpy(), requires_grad=True)
        c2 = Tensor(tc2.detach().numpy(), requires_grad=True)
        l1 = Tensor(tl1.detach().numpy(), requires_grad=True)

        def tinygrad_step():
            """Run one forward and backward pass; return (forward_s, backward_s)."""
            et0 = time.perf_counter()
            x = Tensor.randn(128, 1, 28, 28)
            # NOTE(review): the first pool is avg_pool2d while the torch baseline
            # max-pools in both layers -- confirm whether this mismatch is intended.
            x = x.conv2d(c1).relu().avg_pool2d()
            x = x.conv2d(c2).relu().max_pool2d()
            x = x.reshape(shape=(x.shape[0], -1))
            out = x.dot(l1).log_softmax()
            out = out.mean()
            out.realize()
            et1 = time.perf_counter()
            out.backward()
            for weight in (c1, c2, l1):
                weight.grad.realize()
            et2 = time.perf_counter()
            return et1 - et0, et2 - et1

        # The first iteration is a warm-up: it is neither timed nor profiled.
        tinygrad_step()

        fpt, bpt = 0.0, 0.0
        # The with-block guarantees the profiler is closed even if a step raises.
        with Profiling(sort='time', frac=0.2):
            for _ in range(cnt):
                fwd, bwd = tinygrad_step()
                fpt += fwd
                bpt += bwd

        fpt = fpt * 1000 / cnt
        bpt = bpt * 1000 / cnt
        print(f"forward pass: {fpt:.3f} ms, {fpt / fpt_baseline:.2f}x off baseline {fpt_baseline:.3f} ms")
        print(f"backward pass: {bpt:.3f} ms, {bpt / bpt_baseline:.2f}x off baseline {bpt_baseline:.3f} ms")
# Run the benchmark through the unittest runner when executed as a script.
if __name__ == '__main__':
    unittest.main()
|