test_winograd.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import unittest
  2. from tinygrad import Tensor, GlobalCounters
  3. from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
  4. from tinygrad.ops import MetaOps
  5. from tinygrad.codegen.kernel import Kernel
  6. from tinygrad.engine.schedule import create_schedule
  7. class TestWinograd(unittest.TestCase):
  8. def setUp(self):
  9. self.old = WINO.value
  10. WINO.value = 1
  11. def tearDown(self):
  12. WINO.value = self.old
  13. def test_speed(self):
  14. x = Tensor.empty(1,4,9,9)
  15. w = Tensor.empty(4,4,3,3)
  16. with Timing("running conv: "):
  17. out = Tensor.conv2d(x, w)
  18. with Timing("scheduling: "):
  19. sched = create_schedule([out.lazydata])
  20. for i,s in enumerate(sched):
  21. if s.ast.op is not MetaOps.KERNEL: continue
  22. ops = s.ast.lazyops
  23. with Timing(f"linearize {i} with {len(ops):4d} ops: "):
  24. l = Kernel(s.ast)
  25. l.hand_coded_optimizations()
  26. l.linearize()
  27. assert len(l.sts) <= 256 # just the current value to prevent regression
  28. if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
  29. for st in l.sts:
  30. assert len(st.views) <= 2, "too many views in winograd"
  31. if DEBUG >= 3:
  32. print(f"{len(st.views):3d} views")
  33. for v in st.views: print(v)
  34. def test_profile(self):
  35. x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
  36. with Profiling(enabled=not CI, sort='time'):
  37. out = Tensor.conv2d(x,w).realize()
  38. out.numpy()
  39. def test_four_kernels(self):
  40. x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
  41. GlobalCounters.reset()
  42. out = Tensor.conv2d(x,w).realize()
  43. assert GlobalCounters.kernel_count == 4
  44. out.numpy()
  45. @unittest.skipIf(getenv("PTX"), "winograd uses too much in PTX")
  46. def test_counters(self):
  47. IC, OC, X, Y = 4,4,9,9
  48. #OC, IC, X, Y = 512, 256, 8, 8
  49. x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()
  50. GlobalCounters.reset()
  51. Tensor.conv2d(x,w).realize()
  52. ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem
  53. WINO.value = 0
  54. GlobalCounters.reset()
  55. Tensor.conv2d(x,w).realize()
  56. ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem
  57. ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
  58. print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
  59. print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
  60. self.assertLess(ops_ratio, 2.6) # TODO: there's issues with factorization now
  61. self.assertLess(mem_ratio, 10)
  62. if __name__ == '__main__':
  63. unittest.main(verbosity=2)