# preprocess_imagenet.py (661 B)

import os

from tinygrad import Tensor, dtypes

from extra.datasets.imagenet import iterate, get_val_files
  3. if __name__ == "__main__":
  4. #sz = len(get_val_files())
  5. sz = 32*100
  6. X,Y = None, None
  7. idx = 0
  8. for x,y in iterate(shuffle=False):
  9. print(x.shape, y.shape, x.dtype, y.dtype)
  10. assert x.shape[0] == y.shape[0]
  11. bs = x.shape[0]
  12. if X is None:
  13. X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
  14. Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
  15. print(X.shape, Y.shape)
  16. X[idx:idx+bs].assign(x)
  17. Y[idx:idx+bs].assign(y)
  18. idx += bs
  19. if idx >= sz: break