# multitensor.py — explicit two-device sharded matmul example
  1. import numpy as np
  2. from tinygrad import Tensor, Device, GlobalCounters
  3. from tinygrad.helpers import Timing
  4. d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
  5. N = 256
  6. FLOPS = N*N*N*2
  7. # LazyBuffer should make three fields lists: self.st (all must have the same shape), self.realized, and self.device
  8. def explicit_shard_W_axis_1(X, W):
  9. Xs = [X.to(d0), X.to(d1)]
  10. Ws = [W[:, :N//2].to(d0), W[:, N//2:].to(d1)] # TODO: these shouldn't make copies on the original device
  11. # pad them to form the correct size
  12. Ws = [Ws[0].pad((None, (0,N//2))), Ws[1].pad((None, (N//2,0)))]
  13. for x in Xs: assert x.shape == X.shape
  14. for w in Ws: assert w.shape == W.shape
  15. # TODO: it shouldn't be faster with these realize
  16. for x in Xs+Ws: x.realize()
  17. def lm(x:Tensor, w:Tensor):
  18. # these are movement ops on the local device
  19. x = x.reshape(N, 1, N).expand(N, N, N)
  20. w = w.T.reshape(1, N, N).expand(N, N, N)
  21. m = x*w
  22. assert m.lazydata.st.views[0].mask is not None
  23. ret = m.sum(2)
  24. return ret
  25. #Os = [lm(Xs[0], Ws[0]), lm(Xs[1], Ws[1])]
  26. Os = [Xs[0] @ Ws[0], Xs[1] @ Ws[1]]
  27. for x in Os: x.realize()
  28. return Os[0].to(Device.DEFAULT) + Os[1].to(Device.DEFAULT)
  29. #return Tensor.cat(*[x.to(Device.DEFAULT) for x in Os], dim=1) # TODO: someday we can remove this copy too
  30. def matmul(X, W):
  31. return explicit_shard_W_axis_1(X, W)
  32. #return X@W
  33. if __name__ == "__main__":
  34. with Timing("init devices: "):
  35. Device[d0], Device[d1]
  36. with Timing("create tensors: "):
  37. X = Tensor.kaiming_uniform(N, N).realize()
  38. W = Tensor.kaiming_uniform(N, N).realize()
  39. #with Timing("warmup: "):
  40. # O = matmul(X, W).numpy()
  41. GlobalCounters.reset()
  42. print("******** multiply start")
  43. with Timing("******** multiply done: ", lambda x: f" {FLOPS/x:.2f} GFLOPS"):
  44. O = matmul(X, W).realize()
  45. Device[Device.DEFAULT].synchronize()
  46. with Timing("testing: "):
  47. val = X.numpy() @ W.numpy()
  48. np.testing.assert_allclose(val, O.numpy(), atol=1e-5)