imagenet.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
# To obtain the ImageNet data, download prepare.sh and run it.
  2. import glob, random, json, math
  3. import numpy as np
  4. from PIL import Image
  5. import functools, pathlib
  6. from tinygrad.helpers import diskcache, getenv
  7. @functools.lru_cache(None)
  8. def get_imagenet_categories():
  9. ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
  10. return {v[0]: int(k) for k,v in ci.items()}
  11. if getenv("MNISTMOCK"):
  12. BASEDIR = pathlib.Path(__file__).parent / "mnist"
  13. @functools.lru_cache(None)
  14. def get_train_files():
  15. if not BASEDIR.exists():
  16. from extra.datasets.fake_imagenet_from_mnist import create_fake_mnist_imagenet
  17. create_fake_mnist_imagenet(BASEDIR)
  18. if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
  19. return files
  20. else:
  21. BASEDIR = pathlib.Path(__file__).parent / "imagenet"
  22. @diskcache
  23. def get_train_files():
  24. if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
  25. return files
  26. @functools.lru_cache(None)
  27. def get_val_files():
  28. if not (files:=glob.glob(p:=str(BASEDIR / "val/*/*"))): raise FileNotFoundError(f"No validation files in {p}")
  29. return files
  30. def image_resize(img, size, interpolation):
  31. w, h = img.size
  32. w_new = int((w / h) * size) if w > h else size
  33. h_new = int((h / w) * size) if h > w else size
  34. return img.resize([w_new, h_new], interpolation)
  35. def rand_flip(img):
  36. if random.random() < 0.5:
  37. img = np.flip(img, axis=1).copy()
  38. return img
  39. def center_crop(img):
  40. rescale = min(img.size) / 256
  41. crop_left = (img.width - 224 * rescale) / 2.0
  42. crop_top = (img.height - 224 * rescale) / 2.0
  43. img = img.resize((224, 224), Image.BILINEAR, box=(crop_left, crop_top, crop_left + 224 * rescale, crop_top + 224 * rescale))
  44. return img
  45. # we don't use supplied imagenet bounding boxes, so scale min is just min_object_covered
  46. # https://github.com/tensorflow/tensorflow/blob/e193d8ea7776ef5c6f5d769b6fb9c070213e737a/tensorflow/core/kernels/image/sample_distorted_bounding_box_op.cc
  47. def random_resized_crop(img, size, scale=(0.10, 1.0), ratio=(3/4, 4/3)):
  48. w, h = img.size
  49. area = w * h
  50. # Crop
  51. random_solution_found = False
  52. for _ in range(100):
  53. aspect_ratio = random.uniform(ratio[0], ratio[1])
  54. max_scale = min(min(w * aspect_ratio / h, h / aspect_ratio / w), scale[1])
  55. target_area = area * random.uniform(scale[0], max_scale)
  56. w_new = int(round(math.sqrt(target_area * aspect_ratio)))
  57. h_new = int(round(math.sqrt(target_area / aspect_ratio)))
  58. if 0 < w_new <= w and 0 < h_new <= h:
  59. crop_left = random.randint(0, w - w_new)
  60. crop_top = random.randint(0, h - h_new)
  61. img = img.crop((crop_left, crop_top, crop_left + w_new, crop_top + h_new))
  62. random_solution_found = True
  63. break
  64. if not random_solution_found:
  65. # Center crop
  66. img = center_crop(img)
  67. else:
  68. # Resize
  69. img = img.resize([size, size], Image.BILINEAR)
  70. return img
  71. def preprocess_train(img):
  72. img = random_resized_crop(img, 224)
  73. img = rand_flip(np.array(img))
  74. return img