import unittest
import torch
import tqdm
import torchaudio
import pathlib
import jiwer
import os
import numpy as np
from whisper.normalizers import EnglishTextNormalizer
from examples.whisper import init_whisper, transcribe_waveform

class TestWhisperLibriSpeech(unittest.TestCase):
  # reference WERs determined by running https://github.com/openai/whisper/blob/main/notebooks/LibriSpeech.ipynb
  # the values should be consistent with the paper, section D.1.1: https://cdn.openai.com/papers/whisper.pdf#page=22
  # the tinygrad WERs do not perfectly match the reference values, due to what appear to be precision differences vs torch

  def test_en_tiny(self):
    run_evaluation("tiny.en", 0.056629001883239174, 0.05655609406528749)

  def test_tiny(self):
    run_evaluation("tiny", 0.0771121409407306, 0.07558413638335187)

  def test_en_base(self):
    run_evaluation("base.en", 0.041412520064205455, 0.04271408904897505)

  def test_en_small(self):
    run_evaluation("small.en", 0.03369011117172363, 0.030531615969223228)

def run_evaluation(model_name, tinygrad_expected_wer, reference_wer):
  dataset = LibriSpeech()
  batch_size = 16
  loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

  model, enc = init_whisper(model_name, batch_size=batch_size)

  hypotheses = []
  references = []

  for audio, texts in tqdm.tqdm(loader):
    # truncate=True: only the first 30-second window of each waveform is transcribed
    transcriptions = transcribe_waveform(model, enc, audio.numpy(), truncate=True)
    hypotheses.extend(transcriptions)
    references.extend(texts)

  # normalize both sides with Whisper's English text normalizer before scoring
  normalizer = EnglishTextNormalizer()
  normalized_hypotheses = [normalizer(text) for text in hypotheses]
  normalized_references = [normalizer(text) for text in references]
  # note: jiwer.wer expects (reference, hypothesis); the hypotheses are passed first here,
  # and the pinned tinygrad_expected_wer values appear to have been computed with this ordering
  wer = jiwer.wer(normalized_hypotheses, normalized_references)

  np.testing.assert_almost_equal(wer, tinygrad_expected_wer)
  print(f'tinygrad WER {wer} vs reference WER {reference_wer}')
  del model, enc

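# torch Dataset wrapper around torchaudio's LibriSpeech test-clean split;
# each utterance is padded or trimmed to Whisper's fixed 30-second input window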
class LibriSpeech(torch.utils.data.Dataset):
  def __init__(self):
    folder = pathlib.Path(__file__).parent.parent.parent / "extra" / "datasets" / "librispeech"
    os.makedirs(folder, exist_ok=True)

    self.dataset = torchaudio.datasets.LIBRISPEECH(
      root=folder,
      url="test-clean",
      download=True,
    )

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, item):
    audio, sample_rate, text, _, _, _ = self.dataset[item]
    assert sample_rate == 16000
    return pad_or_trim_tensor(audio[0]), text

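# 480000 samples = 30 seconds at 16 kHz, the fixed input length Whisper expects per segment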
def pad_or_trim_tensor(tensor, target_len=480000):
  curr_len = len(tensor)
  if curr_len == target_len:
    return tensor
  elif curr_len < target_len:
    return torch.cat((tensor, torch.zeros(target_len - curr_len)))
  else:
    return tensor[:target_len]

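# a single model can be evaluated via unittest's standard CLI, e.g.:
#   python <this file> TestWhisperLibriSpeech.test_en_tiny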
if __name__ == '__main__':
  unittest.main()