# whisper.py
  1. # thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code
  2. import sys, base64, multiprocessing, itertools
  3. from typing import Optional, Union, Literal, List
  4. from tinygrad import Tensor, TinyJit, Variable, nn
  5. from tinygrad.nn.state import torch_load, load_state_dict
  6. from tinygrad.helpers import getenv, DEBUG, fetch
  7. import numpy as np
  8. import librosa
class MultiHeadAttention:
  """Multi-head attention with optional key/value caching.

  kv_caching='cross': k/v are computed once from the encoder output (xa) and
  cached, so later decoder steps can omit xa and reuse them.
  kv_caching='self': k/v for past tokens are kept in fixed-size cache buffers
  so single-token decoder steps can attend over the whole prefix.
  """
  def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
    self.n_head = n_head
    self.query = nn.Linear(n_state, n_state)
    self.key = nn.Linear(n_state, n_state, bias=False)
    self.value = nn.Linear(n_state, n_state)
    self.out = nn.Linear(n_state, n_state)
    self.kv_caching = kv_caching
    # only used for kv_caching='self': fixed length of the cache buffers
    self.max_self_attn_cache_len = max_self_attn_cache_len

  def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
    # `len` is the number of valid cached positions (int or symbolic Variable);
    # it shadows the builtin, kept for interface compatibility with callers
    if self.kv_caching == 'cross':
      if xa is not None:
        # first call for this audio segment: compute and cache k/v from xa
        k, v = self.key(xa), self.value(xa)
        if not hasattr(self, 'cache_k'):
          self.cache_k, self.cache_v = k, v
        else:
          # assign into the existing buffers (stable tensors for the JIT)
          self.cache_k.assign(k).realize()
          self.cache_v.assign(v).realize()
      else:
        # subsequent calls reuse the cached encoder k/v
        k, v = self.cache_k, self.cache_v
    else:
      k, v = self.key(x), self.value(x)
      if self.kv_caching == 'self':
        if not hasattr(self, 'cache_k'):
          # lazily allocate fixed-size caches on first use
          self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
          self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
        # prepend the `len` valid cached positions to this step's k/v ...
        k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
        v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
        # ... then zero-pad back to the fixed cache length and store
        padding = self.max_self_attn_cache_len-len-x.shape[1]
        self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
        self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()

    q = self.query(x)
    n_ctx = q.shape[1]
    assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
    head_dim = q.shape[-1] // self.n_head
    # (batch, seq, n_state) -> (batch, n_head, seq, head_dim)
    q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
    # causal mask is cropped to the query length; None disables masking
    attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
    # merge heads back to (batch, seq, n_state) and project out
    wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
    return self.out(wv)
  50. class ResidualAttentionBlock:
  51. def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
  52. self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
  53. self.attn_ln = nn.LayerNorm(n_state)
  54. self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
  55. self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
  56. self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
  57. self.mlp_ln = nn.LayerNorm(n_state)
  58. def __call__(self, x, xa=None, mask=None, len: Union[Variable, int]=None):
  59. x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
  60. if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
  61. x = x + self.mlp_ln(x).sequential(self.mlp)
  62. return x.realize()
  63. class AudioEncoder:
  64. def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
  65. self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
  66. self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
  67. self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
  68. self.ln_post = nn.LayerNorm(n_audio_state)
  69. self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
  70. self.encode = TinyJit(self.__call__)
  71. def __call__(self, x):
  72. x = self.conv1(x).gelu()
  73. x = self.conv2(x).gelu()
  74. x = x.permute(0, 2, 1)
  75. x = x + self.positional_embedding[:x.shape[1]]
  76. x = x.sequential(self.blocks)
  77. x = self.ln_post(x)
  78. return x.realize()
class TextDecoder:
  """Autoregressive token decoder with cross-attention over encoded audio.

  Separate TinyJit instances are kept for the first call (full start-token
  prompt) and subsequent single-token calls, because the JIT specializes on
  input shapes.
  """
  def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
    self.max_tokens_to_sample = n_text_ctx // 2
    self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample

    self.token_embedding = nn.Embedding(n_vocab, n_text_state)
    self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
    self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
    self.ln = nn.LayerNorm(n_text_state)
    # additive causal mask: -inf above the diagonal
    self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
    # separate JITs per phase so each specializes on its own input shapes
    self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
    self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
    self.start_output_tok = TinyJit(self.output_tok)
    self.after_start_output_tok = TinyJit(self.output_tok)

  # if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
  def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
    # x: token ids (batch, seqlen); pos: index of the first token in x within the sequence
    seqlen = x.shape[-1]
    x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
    if pos == 0:
      # first call: run the full prompt; streaming=True skips the JIT because
      # the prompt length varies between calls in that mode
      for block in (self.blocks if streaming else self.blocks_start_tok):
        x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
      return self.output_tok(x) if streaming else self.start_output_tok(x)
    else:
      # single-token step: the cache length is symbolic so one JIT kernel
      # serves every position
      for block in self.blocks_after_start_tok:
        len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
        x = block(x, mask=self.mask, len=len_v)
      return self.after_start_output_tok(x)

  def output_tok(self, x):
    # project hidden states to vocab logits via the tied embedding matrix
    return (self.ln(x) @ self.token_embedding.weight.T).realize()
  107. class Whisper:
  108. def __init__(self, dims, batch_size=1):
  109. self.encoder = AudioEncoder(**dims)
  110. self.decoder = TextDecoder(**dims)
  111. self.is_multilingual = dims["n_vocab"] == 51865
  112. self.batch_size = batch_size
# audio preprocessing constants (match OpenAI Whisper's feature extraction)
RATE = 16000  # expected input sample rate (Hz)
SEGMENT_SECONDS=30  # Whisper processes audio in fixed 30s windows
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
N_FFT = 400  # STFT window size
HOP_LENGTH = 160  # STFT hop (10ms at 16kHz)
N_MELS = 80  # mel filterbank size
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
  """
  :param waveforms: A list of possibly variable length 16000Hz audio samples
  :param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
    Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
  :param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
  :return: mel spectrogram of the given waveforms, shape (batch_size, N_MELS, frames)
  """
  def pad_or_trim(arr, target_len):
    # zero-pad or cut a 1-D waveform to exactly target_len samples
    curr_len = len(arr)
    if curr_len == target_len:
      return arr
    elif curr_len < target_len:
      return np.pad(arr, (0, target_len - curr_len), 'constant')
    else:
      return arr[:target_len]

  # round every waveform up to a whole number of 30s segments
  max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
  if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
  waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
  assert waveforms.shape[0] <= batch_size
  if waveforms.shape[0] < batch_size:
    # we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
    waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))

  # drop the final STFT frame so each 30s segment yields exactly 3000 frames
  stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
  magnitudes = np.absolute(stft[..., :-1]) ** 2
  mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes

  # log compression, clamp dynamic range to 8 dB-decades below the max, rescale
  log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
  log_spec = (log_spec + 4.0) / 4.0

  return log_spec
# language codes supported by multilingual Whisper checkpoints; dict order
# determines each language's token id offset after <|startoftranscript|>
LANGUAGES = {
  "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
  "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
  "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
  "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
  "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
  "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
  "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
  "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
  "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
  "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}
  162. def get_encoding(encoding_name):
  163. with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
  164. ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
  165. n_vocab = len(ranks)
  166. specials = [
  167. "<|endoftext|>",
  168. "<|startoftranscript|>",
  169. *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
  170. "<|translate|>",
  171. "<|transcribe|>",
  172. "<|startoflm|>",
  173. "<|startofprev|>",
  174. "<|nospeech|>",
  175. "<|notimestamps|>",
  176. *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
  177. ]
  178. special_tokens = dict(zip(specials, itertools.count(n_vocab)))
  179. n_vocab += len(specials)
  180. import tiktoken
  181. return tiktoken.Encoding(
  182. name=encoding_name,
  183. explicit_n_vocab=n_vocab,
  184. pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
  185. mergeable_ranks=ranks,
  186. special_tokens=special_tokens)
# checkpoint download URLs, keyed by model name; ".en" variants are English-only
# (note: "large" aliases large-v2)
MODEL_URLS = {
  "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
  "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
  "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
  "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
  "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
  "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
  "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
  "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
  "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
  "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
  "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
  200. def init_whisper(model_name="tiny.en", batch_size=1):
  201. assert MODEL_URLS[model_name] is not None
  202. filename = fetch(MODEL_URLS[model_name])
  203. state = torch_load(filename)
  204. model = Whisper(state['dims'], batch_size)
  205. load_state_dict(model, state['model_state_dict'], strict=False)
  206. enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
  207. return model, enc
  208. def load_file_waveform(filename):
  209. waveform, _ = librosa.load(filename, sr=RATE)
  210. return waveform
def transcribe_file(model, enc, filename):
  """Convenience wrapper: load one audio file and return its transcription."""
  return transcribe_waveform(model, enc, [load_file_waveform(filename)])
def transcribe_waveform(model, enc, waveforms, truncate=False):
  """
  Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
  Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
  """
  N_audio = len(waveforms)
  log_spec = prep_audio(waveforms, model.batch_size, truncate)

  if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
    # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
    # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
    raise Exception("Multi-segment transcription not supported with batch audio input")

  # assemble the decoder's start-of-transcript token sequence
  start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
  if model.is_multilingual:
    # TODO detect language
    language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
    start_tokens.append(language_token)
    start_tokens.append(enc._special_tokens["<|transcribe|>"])
  start_tokens.append(enc._special_tokens["<|notimestamps|>"])
  transcription_start_index = len(start_tokens)
  eot = enc._special_tokens["<|endoftext|>"]
  # accumulated output tokens per batch item (prompt/start tokens excluded)
  transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]

  # decode one 30s segment at a time
  for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
    encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
    pos = 0
    curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
    if curr_frame > 0:
      # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
      prompt = np.concatenate((
        [enc._special_tokens["<|startofprev|>"]],
        transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:],
        start_tokens))
      curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
      transcription_start_index = len(curr_segment_tokens[0])

    # greedy decoding: first step feeds the whole prompt, later steps one token
    for i in range(model.decoder.max_tokens_to_sample):
      out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
      next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
      # once a batch item has emitted eot, keep it pinned at eot
      next_tokens[curr_segment_tokens[:, -1] == eot] = eot
      curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
      pos = curr_segment_tokens.shape[-1] - 1
      if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
      if (curr_segment_tokens[:, -1] == eot).all():
        break

    # strip the prompt/start tokens and anything from eot onwards
    for i, t in enumerate(curr_segment_tokens):
      eot_index = np.where(t == eot)[0]
      eot_index = None if len(eot_index) == 0 else eot_index[0]
      transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))

  transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
  return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
CHUNK = 1600  # samples per microphone read (0.1s at 16kHz)
RECORD_SECONDS = 10  # total live-recording duration for the demo
  263. def listener(q):
  264. import pyaudio
  265. p = pyaudio.PyAudio()
  266. stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
  267. print("listening")
  268. for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
  269. data = stream.read(CHUNK)
  270. waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
  271. q.put(waveform)
  272. print("done listening")
if __name__ == "__main__":
  model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)

  if len(sys.argv) > 1:
    # offline mode: transcribe the file given on the command line
    print(transcribe_file(model, enc, sys.argv[1]))
  else:
    # online
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=listener, args=(q,))
    p.daemon = True
    p.start()

    lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
    total = None
    did_read = False
    for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
      # drain the queue; blocks on q.get() until the first chunk arrives
      while not q.empty() or total is None:
        waveform = q.get()
        if total is None: total = waveform
        else: total = np.concatenate([total, waveform])
        did_read = True
      # did_read latches True after the first chunk and is never reset, so the
      # spectrogram is recomputed from all audio so far on every iteration
      if did_read:
        log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
        encoded_audio = model.encoder.encode(Tensor(log_spec))
      # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
      out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
      idx = int(out[0,-1].argmax().numpy().item())
      lst.append(idx)
      dec = enc.decode(lst)
      print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
      if dec.endswith("<|endoftext|>"):
        # drop the eot token so decoding can continue as more audio arrives
        lst.pop()