from pathlib import Path
from typing import List
import json, argparse, random, time
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from extra.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm

class Tokenizer:
  pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

  def __init__(self, model_path: str):
    mergeable_ranks = load_tiktoken_bpe(model_path)
    self.num_base_tokens = len(mergeable_ranks)
    special_tokens = [
      "<|begin_of_text|>",
      "<|end_of_text|>",
      "<|reserved_special_token_0|>",
      "<|reserved_special_token_1|>",
      "<|reserved_special_token_2|>",
      "<|reserved_special_token_3|>",
      "<|start_header_id|>",
      "<|end_header_id|>",
      "<|reserved_special_token_4|>",
      "<|eot_id|>",
    ] + [
      f"<|reserved_special_token_{i}|>"
      for i in range(5, 256 - 5)
    ]
    self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)

  @property
  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
  @property
  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}

  def decode(self, toks): return self.model.decode([t for t in toks if t < self.num_base_tokens])
  def encode(self, text, allow_special=False):
    return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
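
# Sketch of the id layout this produces, assuming the stock Llama 3 tokenizer.model
# with 128000 base BPE ranks (other vocabs will shift these ids):
#   tokenizer = Tokenizer("tokenizer.model")    # path is a placeholder
#   tokenizer.bos_id                            # 128000, "<|begin_of_text|>"
#   tokenizer.special_tokens["<|eot_id|>"]      # 128009
#   tokenizer.encode("hello")                   # base BPE ids, all < num_base_tokens
# decode() silently drops special tokens, so generated text never contains them.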

# **** helper functions ****
def concat_weights(models, device=None):
  def convert(name) -> Tensor:
    disk_tensors: List[Tensor] = [model[name] for model in models]
    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
      return disk_tensors[0].to(device=device)
    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
    lazy_tensors = [data.to(device=device) for data in disk_tensors]
    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
  return {name: convert(name) for name in {name: None for model in models for name in model}}

def load(fn:str):
  if fn.endswith('.index.json'):
    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
    parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
    return {k: parts[n][k] for k, n in weight_map.items()}
  elif fn.endswith(".safetensors"):
    return safe_load(fn)
  else:
    return torch_load(fn)
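
# The '.index.json' branch follows the Hugging Face sharded-safetensors convention: the
# index maps each tensor name to the shard file that holds it, roughly
#   {"weight_map": {"model.embed_tokens.weight": "model-00001-of-00004.safetensors", ...}}
# so load() reads each referenced shard once and resolves every tensor through the map.
# In concat_weights, wo/w2 concatenate on axis 1 because Meta's consolidated.*.pth
# shards split those projections along the input dimension.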

# **** quantized linears ****
class Int8Linear:
  def __init__(self, in_features, out_features, bias=False):
    assert bias == False
    self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8)
    self.scale = Tensor.ones(out_features, dtype=dtypes.half)

  def __call__(self, x):
    return x.dot(self.weight.cast(dtype=dtypes.half).T * self.scale)

  @staticmethod
  def quantize(tensors, device):
    new_tensors = {}
    for name,v in tensors.items():
      if "feed_forward" in name or "attention.w" in name:
        assert "weight" in name, name
        scale = v.abs().max(axis=1) / 127.0
        int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
        new_tensors[name] = int8_weight
        new_tensors[name.replace('weight', 'scale')] = scale
        if isinstance(device, tuple):
          new_tensors[name].shard_(device, axis=-1)
          new_tensors[name.replace('weight', 'scale')].shard_(device, axis=None)
      else:
        new_tensors[name] = v
    return new_tensors
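
# Numeric sketch of the per-row absmax scheme above (made-up values, not from a real
# checkpoint): for one output row w = [0.5, -1.27, 0.635],
#   scale    = max(|w|) / 127 = 0.01
#   int8 row = cast(w / scale) = [50, -127, 63]   (the cast truncates toward zero)
#   dequant  = int8 row * scale ~= [0.50, -1.27, 0.63]
# __call__ then computes x @ dequant(W).T with one half-precision scale per output row.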

def NF4Linear(block_size):
  _CODE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
  ]
  CODE = Tensor.stack(*[Tensor(c, dtype=dtypes.float16) for c in _CODE])

  class _NF4Linear:
    def __init__(self, in_features, out_features, bias=False):
      assert not bias, "bias not supported"
      self.in_features, self.out_features = in_features, out_features
      self.weight = Tensor.empty(int(out_features * in_features / 2), dtype=dtypes.uint8)
      self.scale = Tensor.empty(int(out_features * in_features / block_size), 1, dtype=dtypes.float16)

    def __call__(self, x: Tensor) -> Tensor:
      high_bits = self.weight
      low_bits = (self.weight * 2 ** 4).contiguous()
      unpacked = Tensor.stack(high_bits, low_bits, dim=-1).div(2 ** 4, upcast=False)
      unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
      return x.linear(unscaled.reshape(self.out_features, self.in_features).T)

    @staticmethod
    def quantize(state_dict: dict[str, Tensor], device) -> dict[str, Tensor]:
      new_state_dict = {}
      for k, v in state_dict.items():
        if "feed_forward" in k or "attention.w" in k:
          grouped = v.reshape(-1, block_size)
          scale = grouped.abs().max(axis=1, keepdim=True)
          coded = ((grouped / scale).unsqueeze(-1) - CODE.to(v.device)).abs().argmin(axis=-1).cast(dtypes.uint8).flatten()
          new_state_dict[k] = coded[::2] * 2 ** 4 + coded[1::2]
          new_state_dict[k.replace(".weight", ".scale")] = scale.cast(dtypes.float16)
          if isinstance(device, tuple):
            new_state_dict[k].shard_(device, axis=-1)
            new_state_dict[k.replace('weight', 'scale')].shard_(device, axis=None)
        else:
          new_state_dict[k] = v
      return new_state_dict
  return _NF4Linear
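
# Packing sketch for _NF4Linear (illustrative indices, not real weights): every weight
# becomes the index of its nearest NF4 code value, and two 4-bit indices share a uint8
# with the first of each pair in the high nibble:
#   coded = [3, 12]  ->  packed = 3 * 16 + 12 = 0x3C
#   unpack: high = 0x3C // 16 = 3,  low = (0x3C * 16) % 256 // 16 = 12
# Dequantization is CODE[index] * scale, one fp16 absmax scale per block_size weights.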

MODEL_PARAMS = {
  "8B": {
    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
    "files": 1
  },
  "70B": {
    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672},
    "files": 8
  }
}
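
# Note: both configurations use grouped-query attention (8 KV heads against 32 or 64
# query heads) and the same 128-wide head, since 4096/32 == 8192/64 == 128.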

def build_transformer(model_path: Path, model_size="8B", quantize=None, device=None):
  # build model
  if quantize == "int8": linear = Int8Linear
  elif quantize == "nf4": linear = NF4Linear(64)
  else: linear = nn.Linear
  with Context(THREEFRY=0):
    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True)

  # load weights
  if model_path.is_dir():
    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
  else:
    weights = load(str(model_path))
  if "model.embed_tokens.weight" in weights:
    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
  weights = fix_bf16(weights)

  with Context(BEAM=0):
    # quantize
    if quantize is not None:
      weights = linear.quantize(weights, device)
      for _,v in weights.items(): v.realize()

    # shard
    if isinstance(device, tuple):
      for k,v in nn.state.get_state_dict(model).items():
        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
        elif '.attention.' in k: v.shard_(device, axis=-1)
        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
        elif 'output.weight' in k: v.shard_(device, axis=0)
        else: v.shard_(device, axis=None)

    # replace weights in model
    load_state_dict(model, weights, strict=False, consume=True)
  return model
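
# Example call (path and shard count are placeholders):
#   device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
#   model = build_transformer(Path("llama3-8b"), model_size="8B", quantize="int8", device=device)
# With a device tuple, the small per-row/per-block scales are replicated across devices
# (axis=None) while the large matmul weights are split along the axes chosen above.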

# default settings
TEMPERATURE = 0.85
TOP_K = 25
TOP_P = 0.9
ALPHA_F = 0.1
ALPHA_P = 0.0

last_seen_toks = []

def prefill(model, toks, start_pos=0):
  global last_seen_toks

  # we can skip part of the prompt if it is the same as last and start_pos=0
  if start_pos == 0:
    for i, (a, b) in enumerate(zip(toks, last_seen_toks)):
      if a != b: break
    else: i = min(len(toks), len(last_seen_toks))
    start_pos += i
    last_seen_toks = toks
    toks = toks[i:]

  # prefill the model
  for tok in tqdm(toks):
    GlobalCounters.reset()
    model(Tensor([[tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).realize()
    start_pos += 1
  return start_pos
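
# Prefix-cache sketch (hypothetical token lists): if last_seen_toks == [1, 2, 3, 4],
#   prefill(model, [1, 2, 3, 9])   # first 3 match, so only [9] runs; returns 4
#   prefill(model, [1, 2, 3, 4])   # full match, nothing runs; returns 4
# This is sound because the KV cache already holds positions 0..i-1 from the previous
# call; the skip is only attempted for a fresh prompt (start_pos == 0).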

if __name__ == "__main__":
  Tensor.no_grad = True

  parser = argparse.ArgumentParser()
  parser.add_argument("--download_model", action="store_true", help="Download an 8B model")
  parser.add_argument("--model", type=Path, help="Model path")
  parser.add_argument("--size", choices=["8B", "70B"], default="8B", help="Model size")
  parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
  parser.add_argument("--quantize", choices=["int8", "nf4"], help="Quantization method")
  parser.add_argument("--no_api", action="store_true", help="Disable the API and run a CLI test interface")
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Web server bind address")
  parser.add_argument("--port", type=int, default=7776, help="Web server port")
  parser.add_argument("--debug", action="store_true", help="Enable debug mode")
  parser.add_argument("--seed", type=int, help="Random seed")
  parser.add_argument("--benchmark", action="store_true", help="Run a benchmark")
  parser.add_argument("--timing", action="store_true", help="Print timing per token")
  parser.add_argument("--profile", action="store_true", help="Output profile data")
  args = parser.parse_args()

  assert not (args.download_model and args.model), "either download or provide model"
  if args.download_model:
    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-8b-sfr")
    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir="llama3-8b-sfr")
    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir="llama3-8b-sfr")
    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir="llama3-8b-sfr")
    fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr")
    args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr")
  assert args.model is not None, "please provide --model option"

  if args.seed is not None: Tensor.manual_seed(args.seed)
  if args.benchmark: Tensor.manual_seed(42)
  print(f"seed = {Tensor._seed}")

  tokenizer = Tokenizer(str((args.model if args.model.is_dir() else args.model.parent) / "tokenizer.model"))
  def encode_role(role: str):
    return [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode(role) + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
  def encode_message(role: str, content: str):
    return encode_role(role) + tokenizer.encode(content.strip()) + [tokenizer.special_tokens["<|eot_id|>"]]

  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(args.shard)) if args.shard > 1 else Device.DEFAULT
  model = build_transformer(args.model, model_size=args.size, quantize=args.quantize, device=device)
  param_bytes = sum(x.lazydata.size * x.dtype.itemsize for x in get_parameters(model))

  if not args.no_api and not args.benchmark:
    from bottle import Bottle, request, response, HTTPResponse, abort, static_file
    app = Bottle()

    cors_headers = {
      "Access-Control-Allow-Origin": "*",
      "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
      "Access-Control-Allow-Headers": "Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token, Authorization",
      "Access-Control-Allow-Credentials": "true",
    }

    @app.hook("before_request")
    def handle_options():
      if request.method == "OPTIONS": raise HTTPResponse(headers=cors_headers)

    @app.hook("after_request")
    def enable_cors():
      for key, value in cors_headers.items(): response.set_header(key, value)

    @app.route("/<filename>")
    def server_static(filename):
      return static_file(filename, root=(Path(__file__).parent / "tinychat").as_posix())

    @app.route("/")
    def index():
      return static_file("index.html", root=(Path(__file__).parent / "tinychat").as_posix())

    @app.get("/v1/models")
    def models():
      return json.dumps([str(args.model)])

    @app.post("/v1/internal/token-count")
    def token_count():
      rjson = json.loads(request.body.read())
      return json.dumps(len(tokenizer.encode(rjson.get("text", ""))))

    @app.post("/v1/token/encode")
    def token_encode():
      rjson = json.loads(request.body.read())
      return json.dumps(tokenizer.encode(rjson.get("text", "")))

    @app.post("/v1/completions")
    def completions():
      rjson = json.loads(request.body.read())

      # check if we are streaming
      if rjson.get("stream", False):
        response.content_type = "text/event-stream"
        response.set_header("Cache-Control", "no-cache")
      else: abort(400, "streaming required")

      toks = [tokenizer.bos_id] + tokenizer.encode(rjson.get("prompt", ""), allow_special=True)
      start_pos = prefill(model, toks[:-1])
      last_tok = toks[-1]
      while True:
        GlobalCounters.reset()
        tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
        start_pos += 1
        last_tok = tok
        if tok in tokenizer.stop_tokens: break

        res = {
          "choices": [{
            "text": tokenizer.decode([tok]),
          }]
        }
        yield f"data: {json.dumps(res)}\n\n"

    @app.post("/v1/chat/token/encode")
    def chat_token_encode():
      rjson = json.loads(request.body.read())
      if "messages" not in rjson: abort(400, "messages required")

      toks = [tokenizer.bos_id]
      for message in rjson["messages"]:
        toks += encode_message(message["role"], message["content"])
      if len(rjson["messages"]) > 0 and message["role"] == "user":
        toks += encode_role("assistant")
      return json.dumps(toks)

    @app.post("/v1/chat/completions")
    def chat_completions():
      global last_seen_toks
      rjson = json.loads(request.body.read())
      if "messages" not in rjson: abort(400, "messages required")

      # check if we are streaming
      if rjson.get("stream", False):
        response.content_type = "text/event-stream"
        response.set_header("Cache-Control", "no-cache")
      else: abort(400, "streaming required")

      toks = [tokenizer.bos_id]
      for message in rjson["messages"]:
        toks += encode_message(message["role"], message["content"])
      # ensure that the last message was a user message
      if message["role"] != "user": abort(400, "last message must be a user message")
      toks += encode_role("assistant")

      random_id = random.randbytes(16).hex()
      start_pos = prefill(model, toks[:-1])
      last_tok = toks[-1]
      last_seen_toks.append(last_tok)
      while True:
        GlobalCounters.reset()
        tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).item()
        start_pos += 1
        last_tok = tok
        last_seen_toks.append(tok)
        if tok in tokenizer.stop_tokens: break

        res = {
          "id": random_id,
          "object": "chat.completion.chunk",
          "created": int(time.time()),
          "model": str(args.model),
          "choices": [{
            "index": 0,
            "delta": {
              "role": "assistant",
              "content": tokenizer.decode([tok]),
            },
            "finish_reason": None,
          }]
        }
        yield f"data: {json.dumps(res)}\n\n"

      res = {
        "id": random_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": str(args.model),
        "choices": [{
          "index": 0,
          "delta": {},
          "finish_reason": "stop",
        }]
      }
      yield f"data: {json.dumps(res)}\n\n"
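
    # Example streaming request against this server (hypothetical shell session, using
    # the default host/port from the args above; the reply text depends on the model):
    #   curl -N http://localhost:7776/v1/chat/completions \
    #     -d '{"stream": true, "messages": [{"role": "user", "content": "Hello."}]}'
    # Chunks arrive as "data: {...}\n\n" server-sent events until finish_reason "stop".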
    app.run(host=args.host, port=args.port, debug=args.debug)
  elif args.benchmark:
    toks = [tokenizer.bos_id] + encode_message("user", "Hello.") + encode_role("assistant")
    start_pos = prefill(model, toks[:-1])
    last_tok = toks[-1]
    generated = ""
    for _ in range(20):
      GlobalCounters.reset()
      st = GlobalCounters.time_sum_s
      with Profiling(enabled=args.profile):
        with Timing("total ", on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
          with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                                              f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                                              (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None):
            tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
          tok = tok.item()
      start_pos += 1
      last_tok = tok
      generated += tokenizer.decode([tok])
    print(generated)
    if "LLaMA-3/8B-SF-DPO" in args.model.as_posix():
      EXPECTED_TEXT = {
        1: "Hello! How can I help you today? If you have any questions or need assistance with anything,",
        2: "Hello! How can I help you today? If you have any questions, need assistance or just want",
        3: "Hello! How can I help you today? If you have any questions or need assistance, feel free",
        4: "Hello! How can I assist you today? If you have any questions, need information, or require",
        5: "Hello! How can I assist you today? If you have any questions or need help with something",
        6: "Hello! How can I assist you today? If you have any questions, need information, or require",
      }
      assert generated == EXPECTED_TEXT[args.shard], f"{generated=} {EXPECTED_TEXT[args.shard]}"
      print("\n" + colored("output validated", "green"))  # NOTE: "\n" inside colored does not render the color in GitHub Actions
  else:
    prompt = [tokenizer.bos_id] + encode_message("system", "You are a helpful assistant.")
    start_pos = prefill(model, prompt)
    while True:
      toks = encode_message("user", input("Q: ")) + encode_role("assistant")
      start_pos = prefill(model, toks[:-1], start_pos=start_pos)
      last_tok = toks[-1]
      while True:
        GlobalCounters.reset()
        if args.timing or args.profile: print("")
        st = GlobalCounters.time_sum_s
        with Profiling(enabled=args.profile):
          with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/s, {GlobalCounters.global_mem/x:.2f} GB/s, param {param_bytes/x:.2f} GB/s"):
            with Timing("enqueue in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
                                                f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                                                (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_bytes*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
              tok = model(Tensor([[last_tok]], device=device), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P)
            tok = tok.item()
        start_pos += 1
        last_tok = tok
        if tok in tokenizer.stop_tokens: break
        print(tokenizer.decode([tok]), end="", flush=True)
      print(flush=True)