# beautiful_cartpole.py — PPO actor-critic trained on Gymnasium's CartPole-v1 (tinygrad example)
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn
import gymnasium as gym
from tinygrad.helpers import trange
import numpy as np # TODO: remove numpy import

# Environment to train on; the commented alternatives were used for other experiments.
ENVIRONMENT_NAME = 'CartPole-v1'
#ENVIRONMENT_NAME = 'LunarLander-v2'
#import examples.rl.lightupbutton
#ENVIRONMENT_NAME = 'PressTheLightUpButton-v0'

# *** hyperparameters ***
# https://github.com/llSourcell/Unity_ML_Agents/blob/master/docs/best-practices-ppo.md
BATCH_SIZE = 256           # minibatch size sampled from the replay buffer per gradient step
ENTROPY_SCALE = 0.0005     # weight of the entropy bonus added to the actor loss
REPLAY_BUFFER_SIZE = 2000  # max (state, action, reward-to-go) samples kept across episodes
PPO_EPSILON = 0.2          # clip range for the PPO probability ratio
HIDDEN_UNITS = 32          # hidden-layer width for both actor and critic
LEARNING_RATE = 1e-2       # Adam learning rate
TRAIN_STEPS = 5            # gradient steps taken after each episode
EPISODES = 40              # total training episodes
DISCOUNT_FACTOR = 0.99     # gamma used for the discounted reward-to-go
  22. class ActorCritic:
  23. def __init__(self, in_features, out_features, hidden_state=HIDDEN_UNITS):
  24. self.l1 = nn.Linear(in_features, hidden_state)
  25. self.l2 = nn.Linear(hidden_state, out_features)
  26. self.c1 = nn.Linear(in_features, hidden_state)
  27. self.c2 = nn.Linear(hidden_state, 1)
  28. def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]:
  29. x = self.l1(obs).tanh()
  30. act = self.l2(x).log_softmax()
  31. x = self.c1(obs).relu()
  32. return act, self.c2(x)
  33. def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
  34. (obs, _), terminated, truncated = test_env.reset(), False, False
  35. total_rew = 0.0
  36. while not terminated and not truncated:
  37. act = model(Tensor(obs))[0].argmax().item()
  38. obs, rew, terminated, truncated, _ = test_env.step(act)
  39. total_rew += float(rew)
  40. return total_rew
if __name__ == "__main__":
  env = gym.make(ENVIRONMENT_NAME)

  # build the actor-critic sized to this environment's observation/action spaces
  model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
  opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LEARNING_RATE)

  @TinyJit
  def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    # One PPO gradient step on a minibatch.
    #   x: batch of observations; selected_action: integer actions taken
    #   reward: discounted reward-to-go per sample; old_log_dist: policy log-probs at collection time
    with Tensor.train():
      log_dist, value = model(x)
      # one-hot mask selecting the taken action's column in the log distribution
      action_mask = (selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)).float()

      # get real advantage using the value function
      advantage = reward.reshape(-1, 1) - value
      # detach so the actor loss does not backprop through the critic
      masked_advantage = action_mask * advantage.detach()

      # PPO clipped surrogate objective: ratio of new to old policy probabilities
      ratios = (log_dist - old_log_dist).exp()
      unclipped_ratio = masked_advantage * ratios
      clipped_ratio = masked_advantage * ratios.clip(1-PPO_EPSILON, 1+PPO_EPSILON)
      # pessimistic (elementwise min) bound, negated because we minimize
      action_loss = -unclipped_ratio.minimum(clipped_ratio).sum(-1).mean()

      entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity
      critic_loss = advantage.square().mean()
      opt.zero_grad()
      (action_loss + entropy_loss*ENTROPY_SCALE + critic_loss).backward()
      opt.step()
      # realize so the JIT returns computed values, not lazy graphs
      return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

  @TinyJit
  def get_action(obs:Tensor) -> Tensor:
    # Sample one action from the current policy (probabilities = exp of log-softmax).
    # TODO: with no_grad
    Tensor.no_grad = True
    ret = model(obs)[0].exp().multinomial().realize()
    Tensor.no_grad = False
    return ret

  st, steps = time.perf_counter(), 0
  # replay buffer: observations, actions, discounted rewards-to-go
  Xn, An, Rn = [], [], []
  for episode_number in (t:=trange(EPISODES)):
    get_action.reset() # NOTE: if you don't reset the jit here it captures the wrong model on the first run through

    # collect one full episode with the current (stochastic) policy
    obs:np.ndarray = env.reset()[0]
    rews, terminated, truncated = [], False, False
    # NOTE: we don't want to early stop since then the rewards are wrong for the last episode
    while not terminated and not truncated:
      # pick actions
      # TODO: what's the temperature here?
      act = get_action(Tensor(obs)).item()

      # save this state action pair
      # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
      Xn.append(np.copy(obs))
      An.append(act)

      obs, rew, terminated, truncated, _ = env.step(act)
      rews.append(float(rew))
    steps += len(rews)

    # reward to go: R[i] = sum_j gamma^(j-i) * rews[j] for j >= i
    # TODO: move this into tinygrad
    discounts = np.power(DISCOUNT_FACTOR, np.arange(len(rews)))
    Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]

    # keep only the most recent REPLAY_BUFFER_SIZE samples
    Xn, An, Rn = Xn[-REPLAY_BUFFER_SIZE:], An[-REPLAY_BUFFER_SIZE:], Rn[-REPLAY_BUFFER_SIZE:]
    X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)

    # TODO: make this work
    #vsz = Variable("sz", 1, REPLAY_BUFFER_SIZE-1).bind(len(Xn))
    #X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)

    # snapshot the current policy's log-probs; PPO ratios are computed against these
    old_log_dist = model(X)[0].detach() # TODO: could save these instead of recomputing
    for i in range(TRAIN_STEPS):
      # random minibatch indices into the replay buffer
      samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize() # TODO: remove the need for this
      # TODO: is this recompiling based on the shape?
      action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
    t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.3f} entropy_loss: {entropy_loss.item():7.3f} critic_loss: {critic_loss.item():8.3f} reward: {sum(rews):6.2f}")

  # final greedy evaluation with on-screen rendering
  test_rew = evaluate(model, gym.make(ENVIRONMENT_NAME, render_mode='human'))
  print(f"test reward: {test_rew}")