#!/usr/bin/env python3
import os, math, time
import numpy as np
from tinygrad import Tensor, nn, fetch, Device, TinyJit, GlobalCounters
from dataclasses import dataclass

@dataclass
class GPTConfig:
  block_size: int = 1024
  vocab_size: int = 50257
  n_layer: int = 12
  n_head: int = 12
  n_embd: int = 768
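  # note: these defaults match the 124M-parameter GPT-2 "small" checkpoint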

class CausalSelfAttention:
  def __init__(self, config:GPTConfig):
    assert config.n_embd % config.n_head == 0
    # key, query, value projections for all heads, but in a batch
    self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
    # output projection
    self.c_proj = nn.Linear(config.n_embd, config.n_embd)
    # regularization
    self.n_head = config.n_head
    self.n_embd = config.n_embd
    # not really a 'bias', more of a mask, but following the OpenAI/HF naming
    self.bias = Tensor.ones(1, 1, config.block_size, config.block_size).tril()
    self.bias.requires_grad = False

  def __call__(self, x:Tensor):
    B, T, C = x.shape
    qkv = self.c_attn(x)
    q, k, v = qkv.split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
    # manual implementation of attention
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
    att = att.softmax()
    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).view(B, T, C) # re-assemble all head outputs side by side
    # output projection
    y = self.c_proj(y)
    return y
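
# NOTE: the attention above is spelled out by hand for clarity; tinygrad also provides
# Tensor.scaled_dot_product_attention (i.e. q.scaled_dot_product_attention(k, v, is_causal=True)),
# which should be an equivalent drop-in for the masked-softmax math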

class MLP:
  def __init__(self, config:GPTConfig):
    self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
    self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

  def __call__(self, x:Tensor) -> Tensor:
    return self.c_proj(self.c_fc(x).gelu())
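
# the transformer block uses pre-LayerNorm residual connections (normalize, then attend/MLP,
# then add back to the stream), matching GPT-2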
class Block:
  def __init__(self, config:GPTConfig):
    self.ln_1 = nn.LayerNorm(config.n_embd)
    self.attn = CausalSelfAttention(config)
    self.ln_2 = nn.LayerNorm(config.n_embd)
    self.mlp = MLP(config)

  def __call__(self, x:Tensor):
    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))
    return x

class GPT:
  def __init__(self, config:GPTConfig):
    self.config = config
    self.wte = nn.Embedding(config.vocab_size, config.n_embd)
    self.wpe = nn.Embedding(config.block_size, config.n_embd)
    self.h = [Block(config) for _ in range(config.n_layer)]
    self.ln_f = nn.LayerNorm(config.n_embd)
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    self.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

  def load_pretrained(self):
    weights = nn.state.torch_load(fetch('https://huggingface.co/gpt2/resolve/main/pytorch_model.bin'))
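    # HF's GPT-2 checkpoint stores these as Conv1D weights with shape (in_features, out_features),
    # so transpose them to match nn.Linear's (out_features, in_features) layout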
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    for k in weights:
      if k.endswith(transposed):
        weights[k] = weights[k].to(Device.DEFAULT).T.contiguous()
    # lm head and wte are tied
    weights['lm_head.weight'] = weights['wte.weight']
    nn.state.load_state_dict(self, weights)
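
  # NOTE: top_k is accepted for API compatibility but not applied below; sampling is plain
  # temperature-scaled multinomial over the full vocabulary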
  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
      idx_cond = idx if idx.shape[1] <= self.config.block_size else idx[:, -self.config.block_size:]
      logits, _ = self(idx_cond)
      logits = logits[:, -1, :] / temperature
      idx_next = logits.softmax().multinomial()
      idx = Tensor.cat(idx, idx_next, dim=1)
    return idx

  def __call__(self, idx:Tensor, targets=None):
    b, t = idx.shape
    pos = Tensor.arange(0, t)
    tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
    pos_emb = self.wpe(pos) # position embeddings of shape (t, n_embd)
    x = tok_emb + pos_emb
    x = self.ln_f(x.sequential(self.h))
    if targets is not None:
      logits = self.lm_head(x)
      loss = logits.sparse_categorical_crossentropy(targets)
    else:
      # inference-time mini-optimization: only run the lm_head on the last position
      logits = self.lm_head(x[:, [-1], :])
      loss = None
    return logits, loss
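
# script flow: load the pretrained GPT-2 (124M) weights, repeatedly train on a single cached
# batch for --num_iterations steps (an overfitting smoke test), then sample a few tokens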
if __name__ == "__main__":
  import tiktoken, argparse

  parser = argparse.ArgumentParser()
  parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
  parser.add_argument("--batch_size", type=int, default=4, help="batch size")
  parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
  args = parser.parse_args()
  B, T = args.batch_size, args.sequence_length
  assert 1 <= T <= 1024

  model = GPT(GPTConfig(n_layer=12, n_head=12, n_embd=768))
  model.load_pretrained()

  # init the tokenizer
  enc = tiktoken.get_encoding("gpt2")
  encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
  decode = lambda l: enc.decode(l)

  # load the tokens
  # prefer tiny_shakespeare if it's available, otherwise use tiny_stories
  # we're using the val instead of the train split just because it is smaller/faster
  shake_tokens_bin = "data/tiny_shakespeare_val.bin"
  story_tokens_bin = "data/TinyStories_val.bin"
  assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run preprocessing on some dataset"
  tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
  assert os.path.isfile(tokens_bin)
  print(f"loading cached tokens in {tokens_bin}")
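  # the cached .bin token files begin with a 1,024-byte (0x400) header, which we skip before
  # reading the raw uint16 token ids (assumption: they were produced by llm.c-style preprocessing)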
  with open(tokens_bin, "rb") as f:
    f.seek(0x400)
    tokens = np.frombuffer(f.read(), dtype=np.uint16).astype(np.int32)
  tokens = Tensor(tokens)

  # lightweight dataloader
  def get_batch():
    assert B*T+1 <= len(tokens), "not enough tokens"
    # e.g. for 338,025 tokens with B=8, T=1024 this will yield 41 batches before looping
    i = 0
    while True:
      x = tokens[i:i+B*T].view(B, T)
      y = tokens[i+1:i+B*T+1].view(B, T)
      yield x, y
      i += B*T
      if i + B*T + 1 >= len(tokens):
        i = 0 # in prod we'd want to randomize the start point a bit

  # forward backward for a few iterations
  data_iter = iter(get_batch())
  x, y = next(data_iter) # we'll overfit this batch below
  optimizer = nn.optim.AdamW(nn.state.get_parameters(model), lr=1e-4, weight_decay=0)
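
  # TinyJit records the kernels launched by the first couple of calls to step and replays the
  # captured graph on later iterations, skipping the Python-side graph construction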
  @TinyJit
  def step(x, y):
    _, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
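
  # Tensor.train() flips Tensor.training on for the duration of the loop, which the optimizer
  # step expects (and which would enable dropout if the model used any)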
  with Tensor.train():
    for i in range(args.num_iterations):
      GlobalCounters.reset()
      t0 = time.time()
      loss = step(x.contiguous(), y.contiguous())
      Device[Device.DEFAULT].synchronize()
      t1 = time.time()
      print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")
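
  # after the short training run, sample max_new_tokens tokens from the "<|endoftext|>" prompt
  # and print the decoded text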
  start = "<|endoftext|>"
  start_ids = encode(start)
  x = (Tensor(start_ids)[None, ...])
  max_new_tokens = 16
  temperature = 1.0
  top_k = 40
  y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
  print(decode(y[0].tolist()))