__init__.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import os, gzip, tarfile, pickle
  2. import numpy as np
  3. from tinygrad import Tensor, dtypes
  4. from tinygrad.helpers import fetch
  5. def fetch_mnist(tensors=False):
  6. parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
  7. BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
  8. X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
  9. Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:].astype(np.int8)
  10. X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
  11. Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:].astype(np.int8)
  12. if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
  13. else: return X_train, Y_train, X_test, Y_test
  14. cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
  15. cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
  16. def fetch_cifar():
  17. X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
  18. Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64)
  19. X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
  20. Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64)
  21. if not os.path.isfile("/tmp/cifar_extracted"):
  22. def _load_disk_tensor(X, Y, db_list):
  23. idx = 0
  24. for db in db_list:
  25. x, y = db[b'data'], np.array(db[b'labels'])
  26. assert x.shape[0] == y.shape[0]
  27. X[idx:idx+x.shape[0]].assign(x)
  28. Y[idx:idx+x.shape[0]].assign(y)
  29. idx += x.shape[0]
  30. assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
  31. print("downloading and extracting CIFAR...")
  32. fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
  33. tt = tarfile.open(fn, mode='r:gz')
  34. _load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
  35. _load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
  36. open("/tmp/cifar_extracted", "wb").close()
  37. return X_train, Y_train, X_test, Y_test