# augment.py (1.4 KB)
  1. import numpy as np
  2. from PIL import Image
  3. from pathlib import Path
  4. import sys
  5. cwd = Path.cwd()
  6. sys.path.append(cwd.as_posix())
  7. sys.path.append((cwd / 'test').as_posix())
  8. from extra.datasets import fetch_mnist
  9. from tqdm import trange
  10. def augment_img(X, rotate=10, px=3):
  11. Xaug = np.zeros_like(X)
  12. for i in trange(len(X)):
  13. im = Image.fromarray(X[i])
  14. im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC)
  15. w, h = X.shape[1:]
  16. #upper left, lower left, lower right, upper right
  17. quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0])
  18. im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
  19. Xaug[i] = im
  20. return Xaug
  21. if __name__ == "__main__":
  22. import matplotlib.pyplot as plt
  23. X_train, Y_train, X_test, Y_test = fetch_mnist()
  24. X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  25. X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  26. X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10)
  27. fig, a = plt.subplots(2,len(X))
  28. Xaug = augment_img(X)
  29. for i in range(len(X)):
  30. a[0][i].imshow(X[i], cmap='gray')
  31. a[1][i].imshow(Xaug[i],cmap='gray')
  32. a[0][i].axis('off')
  33. a[1][i].axis('off')
  34. plt.show()
  35. #create some nice gifs for doc?!
  36. for i in range(10):
  37. im = Image.fromarray(X_train[7353+i])
  38. im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
  39. im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)