inference.py

from pathlib import Path
from typing import List, Tuple
import json, argparse, random, time
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
from inference.shard import Shard
from inference.inference_engine import InferenceEngine
import numpy as np
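
# Llama 3 hyperparameters for the supported model sizes. "files" is the number of
# consolidated.*.pth shards in the original (non-HuggingFace) checkpoint layout.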
MODEL_PARAMS = {
  "8B": {
    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
    "files": 1
  },
  "70B": {
    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672},
    "files": 8
  }
}
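
# Thin wrapper around tiktoken implementing the Llama 3 tokenizer: base BPE ranks are
# loaded from tokenizer.model and the chat-template special tokens are appended after them.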
class Tokenizer:
  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

  def __init__(self, model_path: str):
    mergeable_ranks = load_tiktoken_bpe(model_path)
    self.num_base_tokens = len(mergeable_ranks)
    special_tokens = [
      "<|begin_of_text|>",
      "<|end_of_text|>",
      "<|reserved_special_token_0|>",
      "<|reserved_special_token_1|>",
      "<|reserved_special_token_2|>",
      "<|reserved_special_token_3|>",
      "<|start_header_id|>",
      "<|end_header_id|>",
      "<|reserved_special_token_4|>",
      "<|eot_id|>",
    ] + [
      f"<|reserved_special_token_{i}|>"
      for i in range(5, 256 - 5)
    ]
    self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)

  @property
  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
  @property
  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}

  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
  def encode(self, text, allow_special=False):
    return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
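
# Example usage (hypothetical path; assumes a Llama-3-style tokenizer.model on disk):
#   tok = Tokenizer("checkpoints/Meta-Llama-3-8B/tokenizer.model")
#   ids = tok.encode("hello world")
#   tok.decode(ids)  # special tokens (ids >= num_base_tokens) are dropped on decode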

# **** helper functions ****
def concat_weights(models, device=None):
  def convert(name) -> Tensor:
    disk_tensors: List[Tensor] = [model[name] for model in models]
    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
      return disk_tensors[0].to(device=device)
    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
    lazy_tensors = [data.to(device=device) for data in disk_tensors]
    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
  return {name: convert(name) for name in {name: None for model in models for name in model}}
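
# load() accepts a single weights file: a HuggingFace-style *.index.json mapping tensor
# names to shard files, a .safetensors file, or a torch-pickled .pth checkpoint.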
def load(fn: str):
  if fn.endswith('.index.json'):
    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
    return {k: parts[n][k] for k, n in weight_map.items()}
  elif fn.endswith(".safetensors"):
    return safe_load(fn)
  else:
    return torch_load(fn)

def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
  # build model
  linear = nn.Linear
  with Context(THREEFRY=0):
    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)

  # load weights
  if model_path.is_dir():
    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
  else:
    weights = load(str(model_path))
  if "model.embed_tokens.weight" in weights:
    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
  weights = fix_bf16(weights)

  with Context(BEAM=0):
    # quantize (note: plain nn.Linear has no quantize() helper, so this branch expects a
    # quantized Linear variant to be swapped in; with the default linear, keep quantize=None)
    if quantize is not None:
      weights = linear.quantize(weights, device)
      for _, v in weights.items(): v.realize()

    # shard
    if isinstance(device, tuple):
      for k, v in nn.state.get_state_dict(model).items():
        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
        elif '.attention.' in k: v.shard_(device, axis=-1)
        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
        elif 'output.weight' in k: v.shard_(device, axis=0)
        else: v.shard_(device, axis=None)

    # replace weights in model
    load_state_dict(model, weights, strict=False, consume=True)
  return model
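
# Example (hypothetical local checkpoint directory; uses the default tinygrad device):
#   model = build_transformer(Path("checkpoints/Meta-Llama-3-8B"), model_size="8B")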

# default settings
TEMPERATURE = 0.85
TOP_K = 25
TOP_P = 0.9
ALPHA_F = 0.1
ALPHA_P = 0.0
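
# last_seen_toks caches the previous prompt's tokens so prefill can skip a shared
# prefix whose KV-cache entries are already populated in the model.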
last_seen_toks = []

def prefill(model, toks, start_pos=0):
  global last_seen_toks

  # we can skip part of the prompt if it is the same as last and start_pos=0
  if start_pos == 0:
    for i, (a, b) in enumerate(zip(toks, last_seen_toks)):
      if a != b: break
    else: i = min(len(toks), len(last_seen_toks))
    start_pos += i
    last_seen_toks = toks
    toks = toks[i:]

  # prefill the model
  for tok in tqdm(toks):
    GlobalCounters.reset()
    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
    start_pos += 1
  return start_pos
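
# Tinygrad-backed InferenceEngine: infer_prompt encodes a single-turn chat prompt and
# prefills it, infer_tensor continues from raw token data, and ensure_shard lazily
# loads the model and tokenizer for the requested shard.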
class TinygradDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self):
    self.shard = None

  async def infer_prompt(self, shard: Shard, prompt: str) -> Tuple[np.ndarray, bool]:
    def encode_role(role: str):
      return [self.tokenizer.special_tokens["<|start_header_id|>"]] + self.tokenizer.encode(role) + [self.tokenizer.special_tokens["<|end_header_id|>"]] + self.tokenizer.encode("\n\n")
    def encode_message(role: str, content: str):
      return encode_role(role) + self.tokenizer.encode(content.strip()) + [self.tokenizer.special_tokens["<|eot_id|>"]]

    await self.ensure_shard(shard)
    print([self.tokenizer.encode(prompt)])
    toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
    start_pos = prefill(self.model, toks[:-1])
    last_tok = toks[-1]

    output_data = np.array(self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
    start_pos += 1

    return output_data, output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens

  async def infer_tensor(self, shard: Shard, input_data: np.ndarray) -> Tuple[np.ndarray, bool]:
    await self.ensure_shard(shard)
    output_data: np.ndarray = np.array(self.model(Tensor([input_data]), 0, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist())
    return output_data, output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens

  async def reset_shard(self, shard: Shard):
    await self.ensure_shard(shard)
    print(f"Resetting shard: {shard}")
    self.model.reset()

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return

    model_path = Path(shard.model_id)
    size = "8B"  # one of 8B or 70B for now
    model = build_transformer(model_path, model_size=size)
    tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))

    self.shard = shard
    self.model = model
    self.tokenizer = tokenizer
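
# Example usage (hypothetical; assumes Shard at minimum carries a model_id pointing at a
# local Llama 3 checkpoint directory containing tokenizer.model):
#   import asyncio
#   engine = TinygradDynamicShardInferenceEngine()
#   shard = Shard(model_id="checkpoints/Meta-Llama-3-8B")  # other Shard fields omitted here
#   out, is_finished = asyncio.run(engine.infer_prompt(shard, "What is the capital of France?"))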