# graph_batchnorm.py — one-step training smoke tests for conv/batchnorm graphs.
  1. import unittest
  2. from tinygrad.nn.state import get_parameters
  3. from tinygrad.tensor import Tensor
  4. from tinygrad.nn import Conv2d, BatchNorm2d, optim
  5. def model_step(lm):
  6. with Tensor.train():
  7. x = Tensor.ones(8,12,128,256, requires_grad=False)
  8. optimizer = optim.SGD(get_parameters(lm), lr=0.001)
  9. loss = lm.forward(x).sum()
  10. optimizer.zero_grad()
  11. loss.backward()
  12. del x,loss
  13. optimizer.step()
  14. class TestBatchnorm(unittest.TestCase):
  15. def test_conv(self):
  16. class LilModel:
  17. def __init__(self):
  18. self.c = Conv2d(12, 32, 3, padding=1, bias=False)
  19. def forward(self, x):
  20. return self.c(x).relu()
  21. lm = LilModel()
  22. model_step(lm)
  23. def test_two_conv(self):
  24. class LilModel:
  25. def __init__(self):
  26. self.c = Conv2d(12, 32, 3, padding=1, bias=False)
  27. self.c2 = Conv2d(32, 32, 3, padding=1, bias=False)
  28. def forward(self, x):
  29. return self.c2(self.c(x)).relu()
  30. lm = LilModel()
  31. model_step(lm)
  32. def test_two_conv_bn(self):
  33. class LilModel:
  34. def __init__(self):
  35. self.c = Conv2d(12, 24, 3, padding=1, bias=False)
  36. self.bn = BatchNorm2d(24, track_running_stats=False)
  37. self.c2 = Conv2d(24, 32, 3, padding=1, bias=False)
  38. self.bn2 = BatchNorm2d(32, track_running_stats=False)
  39. def forward(self, x):
  40. x = self.bn(self.c(x)).relu()
  41. return self.bn2(self.c2(x)).relu()
  42. lm = LilModel()
  43. model_step(lm)
  44. def test_conv_bn(self):
  45. class LilModel:
  46. def __init__(self):
  47. self.c = Conv2d(12, 32, 3, padding=1, bias=False)
  48. self.bn = BatchNorm2d(32, track_running_stats=False)
  49. def forward(self, x):
  50. return self.bn(self.c(x)).relu()
  51. lm = LilModel()
  52. model_step(lm)
  53. if __name__ == '__main__':
  54. unittest.main()