# external_test_whisper_librispeech.py
  1. import unittest
  2. import torch
  3. import tqdm
  4. import torchaudio
  5. import pathlib
  6. import jiwer
  7. import os
  8. import numpy as np
  9. from whisper.normalizers import EnglishTextNormalizer
  10. from examples.whisper import init_whisper, transcribe_waveform
  11. class TestWhisperLibriSpeech(unittest.TestCase):
  12. # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
  13. # the values should be consistent with the paper D.1.1 https://cdn.openai.com/papers/whisper.pdf#page=22
  14. # tinygrad WERs do not perfectly match due to what seem to be precision differences vs torch
  15. def test_en_tiny(self):
  16. run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)
  17. def test_tiny(self):
  18. run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)
  19. def test_en_base(self):
  20. run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)
  21. def test_en_small(self):
  22. run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)
  23. def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
  24. dataset = LibriSpeech()
  25. batch_size=16
  26. loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
  27. model, enc = init_whisper(model_name, batch_size=batch_size)
  28. hypotheses = []
  29. references = []
  30. for audio, texts in tqdm.tqdm(loader):
  31. transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True)
  32. hypotheses.extend(transcriptions)
  33. references.extend(texts)
  34. normalizer = EnglishTextNormalizer()
  35. normalized_hypotheses = [normalizer(text) for text in hypotheses]
  36. normalized_references = [normalizer(text) for text in references]
  37. wer = jiwer.wer(normalized_hypotheses, normalized_references)
  38. np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
  39. print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
  40. del model, enc
  41. class LibriSpeech(torch.utils.data.Dataset):
  42. def __init__(self):
  43. folder = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
  44. if not os.path.exists(folder):
  45. os.makedirs(folder)
  46. self.dataset = torchaudio.datasets.LIBRISPEECH(
  47. root=folder,
  48. url="test-clean",
  49. download=True,
  50. )
  51. def __len__(self):
  52. return len(self.dataset)
  53. def __getitem__(self, item):
  54. audio, sample_rate, text, _, _, _ = self.dataset[item]
  55. assert sample_rate == 16000
  56. return pad_or_trim_tensor(audio[0]), text
  57. def pad_or_trim_tensor(tensor, target_len=480000):
  58. curr_len = len(tensor)
  59. if curr_len == target_len:
  60. return tensor
  61. elif curr_len < target_len:
  62. return torch.cat((tensor, torch.zeros(target_len - curr_len)))
  63. else:
  64. return tensor[:target_len]
# Allow running the evaluation directly: python external_test_whisper_librispeech.py
if __name__ == '__main__':
  unittest.main()