# external_jit_failure.py (414 B)

  1. from tinygrad import Tensor, TinyJit, Device
  2. import numpy as np
  3. GPUS = 4
  4. N = 128
  5. ds = tuple([Device.canonicalize(f"{Device.DEFAULT}:{i}") for i in range(GPUS)])
  6. t = Tensor.rand(N, N, N).shard(ds, 0)
  7. n = t.numpy()
  8. @TinyJit
  9. def allreduce(t:Tensor) -> Tensor:
  10. return t.sum(0) #.realize()
  11. for i in range(10):
  12. print(i)
  13. tn = allreduce(t).numpy()
  14. np.testing.assert_allclose(tn, n.sum(0), atol=1e-4, rtol=1e-4)