# conversation.py
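"""Voice chat demo: capture microphone audio, transcribe it with Whisper, generate a reply with a
TinyLlama chat model, synthesize the reply with VITS, and play it back through the speakers."""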
import argparse
import multiprocessing as mp
import os
import re
import sys
import time
from contextlib import contextmanager
from pathlib import Path

import numpy as np
import pyaudio
import yaml
from llama import LLaMa
from vits import MODELS as VITS_MODELS
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
from whisper import init_whisper, transcribe_waveform
from sentencepiece import SentencePieceProcessor

from tinygrad.helpers import Timing, fetch
from tinygrad import Tensor, dtypes
# Whisper constants
RATE = 16000
CHUNK = 1600
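# each read of CHUNK frames at RATE Hz is 1600/16000 = 100 ms of audio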
# LLaMa constants
IM_START = 32001
IM_END = 32002
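# Ids of the <|im_start|>/<|im_end|> tokens appended by create_fixed_tokenizer below
# ([PAD] takes id 32000, extending the base 32000-token vocab to 32003).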
# Helpers for encoding prompts in the ChatML format
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
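# encode_prompt lays one message out as <|im_start|> "k\nv" <|im_end|> "\n" in token ids;
# start_prompt opens a response turn that llama_generate then completes token by token.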
def chunks(lst, n):
  for i in range(0, len(lst), n): yield lst[i:i + n]
def create_fixed_tokenizer():
  """Extends the TinyLlama tokenizer with the additional chat tokens ([PAD], <|im_start|>, <|im_end|>)."""
  import extra.junk.sentencepiece_model_pb2 as spb2
  tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
  if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
    print("creating fixed tokenizer")
    mp = spb2.ModelProto()
    mp.ParseFromString(tokenizer_path.read_bytes())
    # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0))
    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
    tokenizer_path.write_bytes(mp.SerializeToString())
  return tokenizer_path
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, int, str]:
  """Warms up the llama model on the pre-prompt from the specified yaml file."""
  with open(str(pre_prompt_path)) as f:
    config = yaml.safe_load(f.read())
  toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
  for i in config["examples"]:
    toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
    toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])

  llama.model(Tensor([toks]), 0, temperature).realize()  # NOTE: outputs are not used
  return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
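# The pre-prompt YAML (--llama_pre_prompt_path) must provide the keys read above. An illustrative
# sketch (values are made up here, not copied from the shipped pre_prompt_stacy.yaml):
#
#   pre_prompt: "You are a helpful voice assistant. Keep answers short."
#   user_delim: "user"
#   resp_delim: "assistant"
#   examples:
#     - user_prompt: "Hello!"
#       resp_prompt: "Hi there, how can I help?"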
def llama_generate(
  llama: LLaMa,
  toks: list[int],
  outputted: str,
  prompt: str,
  start_pos: int,
  user_delim: str,
  resp_delim: str,
  temperature=0.7,
  max_tokens=1000
):
  """Generates an output for the specified prompt"""
  toks += encode_prompt(llama.tokenizer, user_delim, prompt)
  toks += start_prompt(llama.tokenizer, resp_delim)

  outputted = llama.tokenizer.decode(toks)
  init_length = len(outputted)
  for _ in range(max_tokens):
    # only feed the tokens the model has not seen yet; start_pos tracks how far it has decoded
    token = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
    start_pos = len(toks)
    toks.append(token)

    cur = llama.tokenizer.decode(toks)
    # Print is just for debugging
    sys.stdout.write(cur[len(outputted):])
    sys.stdout.flush()
    outputted = cur
    if toks[-1] == IM_END: break
  else:
    # max_tokens was reached without the model emitting <|im_end|>, so close the turn ourselves
    toks.append(IM_END)
  print()  # because the output is flushed
  return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
def tts(
  text_to_synthesize: str,
  synth: Synthesizer,
  hps: HParams,
  emotion_embedding: Path,
  speaker_id: int,
  model_to_use: str,
  noise_scale: float,
  noise_scale_w: float,
  length_scale: float,
  estimate_max_y_length: bool,
  text_mapper: TextMapper,
  model_has_multiple_speakers: bool,
  pad_length=600,
  vits_pad_length=1000
):
  if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())

  # Convert the input text to a tensor.
  stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
  init_shape = stn_tst.shape
  assert init_shape[0] < pad_length, "text is too long"
  x_tst, x_tst_lengths = stn_tst.pad(((0, pad_length - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
  sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None

  # Perform inference.
  audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
                             max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, pad_length=vits_pad_length)[0, 0]

  # Convert to 16-bit PCM samples for the pyaudio output stream.
  audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
  return audio_data
def init_vits(
  model_to_use: str,
  emotion_path: Path,
  speaker_id: int,
  seed: int,
):
  model_config = VITS_MODELS[model_to_use]

  # Load the hyperparameters from the config file.
  hps = get_hparams_from_file(fetch(model_config[0]))

  # If the model has multiple speakers, validate the speaker id and retrieve the name if available.
  model_has_multiple_speakers = hps.data.n_speakers > 0
  if model_has_multiple_speakers:
    if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
    if hps.__contains__("speakers"):  # maps speaker ids to names
      speakers = hps.speakers
      if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}

  # Load emotions if any. TODO: find an english model with emotions, this is untested atm.
  emotion_embedding = None
  if emotion_path is not None:
    if emotion_path.suffix == ".npy": emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
    else: raise ValueError("Emotion path must be a .npy file.")

  # Load symbols, instantiate TextMapper and clean the text.
  if hps.__contains__("symbols"): symbols = hps.symbols
  elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
  else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
  text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)

  # Load the model.
  Tensor.no_grad = True
  if seed is not None:
    Tensor.manual_seed(seed)
    np.random.seed(seed)
  net_g = load_model(text_mapper.symbols, hps, model_config)

  return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
@contextmanager
def output_stream(num_channels: int, sample_rate: int):
  try:
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
    yield stream
  except KeyboardInterrupt: pass
  finally:
    stream.stop_stream()
    stream.close()
    p.terminate()
@contextmanager
def log_writer():
  try:
    logs = []
    yield logs
  finally:
    sep = "="*os.get_terminal_size()[0]
    print(f"{sep[:-1]}\nCHAT LOG")
    print(*logs, sep="\n")
    print(sep)
def listener(q: mp.Queue, event: mp.Event):
  try:
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
    did_print = False
    while True:
      data = stream.read(CHUNK)  # read data to avoid overflow
      if event.is_set():
        if not did_print:
          print("listening")
          did_print = True
        # int16 samples -> float32 in [-1, 1), boosted 3x, queued for whisper
        q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3))
      else:
        did_print = False
  finally:
    stream.stop_stream()
    stream.close()
    p.terminate()
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
  with output_stream(num_channels, sample_rate) as stream:
    while True:
      try:
        stream.write(q.get())
        counter.value += 1  # the main loop watches this to know how many queued chunks have been played
      except KeyboardInterrupt:
        break
if __name__ == "__main__":
  import nltk
  nltk.download("punkt")
  Tensor.no_grad = True

  # Parse CLI arguments
  parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")

  # Whisper args
  parser.add_argument("--whisper_model_name", type=str, default="tiny.en")

  # LLAMA args
  parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed.")
  parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
  parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
  parser.add_argument("--llama_quantize", type=str, default=None, help="Quantize the weights to int8 or nf4 in memory")
  parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
  parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
  parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
  parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")

  # vits args
  parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
  parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 12.")
  parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
  parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
  parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
  parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is None.")
  parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
  parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
  parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
  parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length. This prevents premature realization and is much more performant for larger inputs, less so for smaller ones. Default is False.")
  parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")

  # conversation args
  parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")

  args = parser.parse_args()
  # Init models
  model, enc = init_whisper(args.whisper_model_name)
  synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)

  # Download tinyllama chat as a default model
  if args.llama_model is None:
    args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
    args.llama_gen = "tiny"
    args.llama_size = "1B-Chat"

  # Add 3 more tokens to the tokenizer
  if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
  tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
  llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
  toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)

  # Start child process for mic input
  q = mp.Queue()
  is_listening_event = mp.Event()
  p = mp.Process(target=listener, args=(q, is_listening_event,))
  p.daemon = True
  p.start()

  # Start child process for speaker output
  out_q = mp.Queue()
  out_counter = mp.Value("i", 0)
  out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
  out_p.daemon = True
  out_p.start()
  # Warm up (JIT) the tts pipeline on two throwaway sentences
  for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
    tts(
      i, synth, hps, emotion_embedding,
      args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
      args.vits_noise_scale_w, args.vits_length_scale,
      args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
    )
  # Start the pipeline
  with log_writer() as log:
    while True:
      tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
      total = np.array([])
      out_counter.value = 0

      s = time.perf_counter()
      is_listening_event.set()
      prev_text = None
      while True:
        for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])  # pull in roughly one second of audio
        txt = transcribe_waveform(model, enc, [total], truncate=True)
        print(txt, end="\r")
        # ignore blank audio and bracketed noise annotations
        if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
        # two identical transcriptions in a row means the user has stopped talking
        if prev_text is not None and prev_text == txt:
          is_listening_event.clear()
          break
        prev_text = txt
      print()  # to avoid llama printing on the same line
      log.append(f"{user_delim.capitalize()}: {txt}")

      # Generate with llama
      with Timing("llama generation: "):
        outputted, start_pos, response = llama_generate(
          llama, toks, outputted, txt, start_pos,
          user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
          max_tokens=args.llama_count
        )
        log.append(f"{resp_delim.capitalize()}: {response}")

      # Convert to voice
      with Timing("tts: "):
        sentences = nltk.sent_tokenize(response.replace('"', ""))
        for i in sentences:
          total = np.array([], dtype=np.int16)
          for j in chunks(i.split(), args.max_sentence_length):
            audio_data = tts(
              " ".join(j), synth, hps, emotion_embedding,
              args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
              args.vits_noise_scale_w, args.vits_length_scale,
              args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
            )
            total = np.concatenate([total, audio_data])
          out_q.put(total.tobytes())
        # busy-wait until the output process has played every queued sentence
        while out_counter.value < len(sentences): continue
      log.append(f"Total: {time.perf_counter() - s}")
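# Example invocation, assuming this script sits next to llama.py, vits.py and whisper.py so the
# imports above resolve (the flag values shown are just the defaults):
#   python conversation.py --whisper_model_name tiny.en --vits_model_to_use vctk --vits_speaker_id 12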