| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import unittest
- from tinygrad import Tensor, GlobalCounters
- from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
- from tinygrad.ops import MetaOps
- from tinygrad.codegen.kernel import Kernel
- from tinygrad.engine.schedule import create_schedule
class TestWinograd(unittest.TestCase):
  """Tests for conv2d with the WINO flag switched on.

  Each test runs with ``WINO.value = 1`` (set in ``setUp``); the previous
  value is restored in ``tearDown``. All checks use ``unittest`` assertion
  methods so they still run under ``python -O`` and report the offending
  values on failure.
  """

  def setUp(self):
    # Save the current flag so tearDown can restore it, then force WINO on.
    self.old = WINO.value
    WINO.value = 1

  def tearDown(self):
    # Undo whatever the test (or setUp) did to the flag.
    WINO.value = self.old

  def test_speed(self):
    # Times building, scheduling and linearizing a small conv, and bounds the
    # number of shapetrackers / views each kernel ends up with.
    x = Tensor.empty(1,4,9,9)
    w = Tensor.empty(4,4,3,3)

    with Timing("running conv: "):
      out = Tensor.conv2d(x, w)

    with Timing("scheduling: "):
      sched = create_schedule([out.lazydata])

    for i,s in enumerate(sched):
      # Only schedule items whose AST root is a KERNEL are linearized.
      if s.ast.op is not MetaOps.KERNEL: continue
      ops = s.ast.lazyops
      with Timing(f"linearize {i} with {len(ops):4d} ops: "):
        kernel = Kernel(s.ast)
        kernel.hand_coded_optimizations()
        kernel.linearize()

      # just the current value to prevent regression
      self.assertLessEqual(len(kernel.sts), 256)
      if DEBUG >= 2:
        print(f"{len(kernel.sts):4d} shapetrackers with max {max(len(t.views) for t in kernel.sts)} views")
      for st in kernel.sts:
        self.assertLessEqual(len(st.views), 2, "too many views in winograd")
        if DEBUG >= 3:
          print(f"{len(st.views):3d} views")
          for v in st.views: print(v)

  def test_profile(self):
    # Runs a realized conv under the profiler (profiling is disabled on CI).
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    with Profiling(enabled=not CI, sort='time'):
      out = Tensor.conv2d(x,w).realize()
    out.numpy()

  def test_four_kernels(self):
    # Inputs are realized before the counters are reset, so only the kernels
    # launched by the conv itself are counted.
    x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
    GlobalCounters.reset()
    out = Tensor.conv2d(x,w).realize()
    self.assertEqual(GlobalCounters.kernel_count, 4)
    out.numpy()

  @unittest.skipIf(getenv("PTX"), "winograd uses too much in PTX")
  def test_counters(self):
    # Compares global op / memory counters of the same conv with WINO on
    # (from setUp) and with WINO off.
    IC, OC, X, Y = 4,4,9,9
    #OC, IC, X, Y = 512, 256, 8, 8
    x,w = Tensor.rand(1,IC,Y,X).realize(), Tensor.rand(OC,IC,3,3).realize()

    GlobalCounters.reset()
    Tensor.conv2d(x,w).realize()
    ops_wino, mem_wino = GlobalCounters.global_ops, GlobalCounters.global_mem

    # Switch the flag off for the baseline run; tearDown restores the saved value.
    WINO.value = 0
    GlobalCounters.reset()
    Tensor.conv2d(x,w).realize()
    ops_normal, mem_normal = GlobalCounters.global_ops, GlobalCounters.global_mem

    ops_ratio, mem_ratio = ops_wino/ops_normal, mem_wino/mem_normal
    print(f"ops: normal {ops_normal:9d} wino {ops_wino:9d} ratio {ops_ratio:.2f}")
    print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
    self.assertLess(ops_ratio, 2.6)  # TODO: there's issues with factorization now
    self.assertLess(mem_ratio, 10)
# Run the suite with per-test output when this file is executed directly.
if __name__ == "__main__":
  unittest.main(verbosity=2)
|