# inference.py

from pathlib import Path
from typing import List, Optional, Tuple
import json, argparse, random, time
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
from tinygrad.helpers import DEBUG, tqdm, _cache_dir
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
import numpy as np
MODEL_PARAMS = {
  "8B": {
    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
    "files": 1
  },
  "70B": {
    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672},
    "files": 8
  }
}
# Llama 3 tokenizer: tiktoken BPE ranks plus the Llama 3 special tokens.
class Tokenizer:
  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

  def __init__(self, model_path: str):
    mergeable_ranks = load_tiktoken_bpe(model_path)
    self.num_base_tokens = len(mergeable_ranks)
    special_tokens = [
      "<|begin_of_text|>",
      "<|end_of_text|>",
      "<|reserved_special_token_0|>",
      "<|reserved_special_token_1|>",
      "<|reserved_special_token_2|>",
      "<|reserved_special_token_3|>",
      "<|start_header_id|>",
      "<|end_header_id|>",
      "<|reserved_special_token_4|>",
      "<|eot_id|>",
    ] + [
      f"<|reserved_special_token_{i}|>"
      for i in range(5, 256 - 5)
    ]
    self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)

  @property
  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]

  @property
  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}

  def decode(self, toks):
    # special tokens are dropped on decode; only base BPE tokens are rendered as text
    return self.model.decode([t for t in toks if t < self.num_base_tokens])

  def encode(self, text, allow_special=False):
    return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
# **** helper functions ****
def concat_weights(models, device=None):
  def convert(name) -> Tensor:
    disk_tensors: List[Tensor] = [model[name] for model in models]
    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
      return disk_tensors[0].to(device=device)
    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
    lazy_tensors = [data.to(device=device) for data in disk_tensors]
    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
  return {name: convert(name) for name in {name: None for model in models for name in model}}

def load(fn: str):
  if fn.endswith('.index.json'):
    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
    return {k: parts[n][k] for k, n in weight_map.items()}
  elif fn.endswith(".safetensors"):
    return safe_load(fn)
  else:
    return torch_load(fn)
def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
  # build model
  linear = nn.Linear
  with Context(THREEFRY=0):
    model = Transformer(**MODEL_PARAMS[model_size]["args"], shard=shard, linear=linear, max_context=8192, jit=False)

  # load weights
  if model_path.is_dir():
    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
  else:
    weights = load(str(model_path))
  if "model.embed_tokens.weight" in weights:
    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"], shard=shard)
  weights = fix_bf16(weights)

  with Context(BEAM=0):
    # quantize
    if quantize is not None:
      weights = linear.quantize(weights, device)
      for _, v in weights.items(): v.realize()

    # shard
    if isinstance(device, tuple):
      for k, v in nn.state.get_state_dict(model).items():
        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
        elif '.attention.' in k: v.shard_(device, axis=-1)
        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
        elif 'output.weight' in k: v.shard_(device, axis=0)
        else: v.shard_(device, axis=None)

    # replace weights in model
    load_state_dict(model, weights, strict=False, consume=True)
  return model
# default settings
TEMPERATURE = 0  # 0.85
TOP_K = 25
TOP_P = 0.9
ALPHA_F = 0.1
ALPHA_P = 0.0

def prefill(model, toks, start_pos=0):
  # prefill the model one token at a time, filling the KV cache and advancing start_pos
  for tok in tqdm(toks):
    GlobalCounters.reset()
    model(Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
    start_pos += 1
  return start_pos
# Inference engine that runs a shard of the Llama model with tinygrad.
class TinygradDynamicShardInferenceEngine(InferenceEngine):
  def __init__(self):
    self.shard = None

  async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    def encode_role(role: str):
      return [self.tokenizer.special_tokens["<|start_header_id|>"]] + self.tokenizer.encode(role) + [self.tokenizer.special_tokens["<|end_header_id|>"]] + self.tokenizer.encode("\n\n")
    def encode_message(role: str, content: str):
      return encode_role(role) + self.tokenizer.encode(content.strip()) + [self.tokenizer.special_tokens["<|eot_id|>"]]

    await self.ensure_shard(shard)
    start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0

    toks = [self.tokenizer.bos_id] + encode_message("user", prompt) + encode_role("assistant")
    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
    last_tok = toks[-1]

    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
    if output_data.size == 1:
      start_pos += 1

    return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens

  async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    await self.ensure_shard(shard)
    start_pos = json.loads(inference_state)["start_pos"] if inference_state else 0

    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
    if output_data.size == 1:
      start_pos += 1

    return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens

  async def reset_shard(self, shard: Shard):
    await self.ensure_shard(shard)
    self.model.reset()
  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return

    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
    model_path = models_dir / shard.model_id
    size = "8B"
    if model_path.exists():
      model = model_path
    else:
      from tinygrad.helpers import fetch

      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
        fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
        fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
        fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
        fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
        fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
        model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
        size = "8B"
      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
        # size = "70B"
      else:
        raise ValueError(f"tinygrad doesn't currently support arbitrary model downloading. unsupported model: {shard.model_id}")

    model = build_transformer(model_path, shard=shard, model_size=size)
    tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))

    self.shard = shard
    self.model = model
    self.tokenizer = tokenizer
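
# A minimal usage sketch (not part of the engine): prefill a prompt with infer_prompt,
# then feed each sampled token back through infer_tensor until a stop token appears.
# The Shard field names/values and the "llama3-8b-sfr" model id below are assumptions
# for illustration; exo normally constructs the shards itself.
if __name__ == "__main__":
  import asyncio

  async def _demo():
    engine = TinygradDynamicShardInferenceEngine()
    # hypothetical single-node shard covering all 32 layers of the 8B model
    shard = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=31, n_layers=32)
    output, state, is_finished = await engine.infer_prompt(shard, "What is the capital of France?")
    toks = []
    while not is_finished and len(toks) < 64:
      tok = int(output.item())
      toks.append(tok)
      output, state, is_finished = await engine.infer_tensor(shard, np.array([tok]), state)
    print(engine.tokenizer.decode(toks))

  asyncio.run(_demo())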