| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- from tinygrad.tensor import Tensor
- from tinygrad.engine.jit import TinyJit
- from tinygrad.nn import Linear, Embedding
- from tinygrad.helpers import fetch
- import numpy as np
- from pathlib import Path
class RNNT:
    """RNN-Transducer built from an encoder, a prediction network and a joint network.

    The joint network emits ``vocab_size`` logits. Greedy decoding treats the
    last index (``vocab_size - 1``) as the blank symbol, which is why the
    prediction network's embedding table only has ``vocab_size - 1`` rows.
    """

    # Upper bound on joint-network evaluations per encoder frame in greedy decoding.
    MAX_SYMBOLS_PER_STEP = 30

    def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320,
                 joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2,
                 stack_time_factor=2, dropout=0.32):
        """Build the three sub-networks; all arguments are forwarded to them."""
        # Kept so decoding can slice the joint output without hard-coded sizes.
        self.vocab_size = vocab_size
        self.blank_id = vocab_size - 1
        self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
        self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
        self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)

    @TinyJit
    def __call__(self, x, y, hc=None):
        """Run encoder, prediction and joint networks and return the realized joint output.

        Args:
            x: encoder input; the LSTM iterates over its first axis.
            y: label indices fed to the prediction network's embedding.
            hc: optional initial prediction-network state (zeros when ``None``).
        """
        f, _ = self.encoder(x, None)
        # An all-ones mask leaves the embedded labels unchanged.
        g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
        out = self.joint(f, g)
        return out.realize()

    def decode(self, x, x_lens):
        """Greedy-decode every batch entry of ``x``.

        Args:
            x: encoder input.
            x_lens: per-entry input lengths; the encoder scales them by the
                stacking factor and they are rounded up before decoding.

        Returns:
            A list with one label sequence per batch entry.
        """
        logits, logit_lens = self.encoder(x, x_lens)
        outputs = []
        for b in range(logits.shape[0]):
            # (T, H) -> (T, 1, H): _greedy_decode indexes the time axis first.
            inseq = logits[b, :, :].unsqueeze(1)
            n_frames = int(np.ceil(logit_lens[b].numpy()).item())
            outputs.append(self._greedy_decode(inseq, n_frames))
        return outputs

    def _greedy_decode(self, logits, logit_len):
        """Greedy transducer search over the first ``logit_len`` frames of ``logits``.

        For each frame, symbols are emitted until the blank symbol wins or
        ``MAX_SYMBOLS_PER_STEP`` evaluations were made. The prediction state
        only advances when a non-blank symbol is emitted.

        Returns:
            The emitted labels, as the NumPy values produced by ``argmax().numpy()``.
        """
        vocab, blank = self.vocab_size, self.blank_id
        hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
        labels = []
        label = Tensor.zeros(1, 1, requires_grad=False)
        # Zero mask blanks out the embedding until the first label has been emitted.
        mask = Tensor.zeros(1, requires_grad=False)
        for time_idx in range(logit_len):
            logit = logits[time_idx, :, :].unsqueeze(0)
            not_blank = True
            added = 0
            while not_blank and added < self.MAX_SYMBOLS_PER_STEP:
                if labels:
                    mask = (mask + 1).clip(0, 1)
                    prev = labels[-1]
                    # `+ 1 - 1` is value-preserving; presumably it gives the JIT a
                    # freshly computed input tensor -- TODO confirm.
                    label = Tensor([[prev if prev <= blank else prev - 1]], requires_grad=False) + 1 - 1
                jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
                # _pred_joint packs [joint logits | new hc] along the last axis.
                k = jhc[0, 0, :vocab].argmax(axis=0).numpy()
                not_blank = k != blank
                if not_blank:
                    labels.append(k)
                    hc = jhc[:, :, vocab:] + 1 - 1
                added += 1
        return labels

    @TinyJit
    def _pred_joint(self, logit, label, hc, mask):
        """One prediction + joint step, returned as a single realized tensor.

        The joint logits are zero-padded up to ``hc``'s two leading dimensions
        and concatenated with the new ``hc`` along the last axis, so the caller
        gets both results from one tensor.
        """
        g, hc = self.prediction(label, hc, mask)
        j = self.joint(logit, g)[0]
        # Pad to hc's leading dims instead of assuming exactly two prediction layers.
        j = j.pad(((0, hc.shape[0] - j.shape[0]), (0, hc.shape[1] - j.shape[1]), (0, 0)))
        out = j.cat(hc, dim=2)
        return out.realize()

    @staticmethod
    def _load_lstm(lstm, state_dict, prefix):
        """Copy ``{prefix}.{weight,bias}_{ih,hh}_l{i}`` checkpoint entries into each cell of ``lstm``."""
        for i, cell in enumerate(lstm.cells):
            cell.weights_ih.assign(state_dict[f"{prefix}.weight_ih_l{i}"].numpy())
            cell.weights_hh.assign(state_dict[f"{prefix}.weight_hh_l{i}"].numpy())
            cell.bias_ih.assign(state_dict[f"{prefix}.bias_ih_l{i}"].numpy())
            cell.bias_hh.assign(state_dict[f"{prefix}.bias_hh_l{i}"].numpy())

    def load_from_pretrained(self):
        """Fetch the pretrained checkpoint and copy its weights into this model.

        Raises:
            KeyError: if the model has an LSTM layer the checkpoint does not provide.
        """
        fn = Path(__file__).parents[1] / "weights/rnnt.pt"
        fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)

        import torch
        # NOTE: torch.load unpickles the downloaded file; only use with a trusted source.
        with open(fn, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")["state_dict"]

        # encoder
        self._load_lstm(self.encoder.pre_rnn, state_dict, "encoder.pre_rnn.lstm")
        self._load_lstm(self.encoder.post_rnn, state_dict, "encoder.post_rnn.lstm")

        # prediction
        self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
        self._load_lstm(self.prediction.rnn, state_dict, "prediction.dec_rnn.lstm")

        # joint
        self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
        self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
        self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
        self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())
class LSTMCell:
    """One LSTM layer step with separate input-to-hidden and hidden-to-hidden parameters."""

    def __init__(self, input_size, hidden_size, dropout):
        """Create uniformly initialised parameters for the four stacked gates."""
        gate_rows = hidden_size * 4
        self.dropout = dropout
        self.weights_ih = Tensor.uniform(gate_rows, input_size)
        self.bias_ih = Tensor.uniform(gate_rows)
        self.weights_hh = Tensor.uniform(gate_rows, hidden_size)
        self.bias_hh = Tensor.uniform(gate_rows)

    def __call__(self, x, hc):
        """Advance the cell by one step.

        Args:
            x: input whose first axis has ``batch`` rows.
            hc: previous state, with the hidden rows first (``hc[:batch]``)
                and the cell rows after them (``hc[batch:]``).

        Returns:
            The realized new state, hidden rows followed by cell rows.
        """
        batch = x.shape[0]
        h_prev, c_prev = hc[:batch], hc[batch:]

        gates = x.linear(self.weights_ih.T, self.bias_ih) + h_prev.linear(self.weights_hh.T, self.bias_hh)
        in_gate, forget_gate, candidate, out_gate = gates.chunk(4, 1)

        c = (forget_gate.sigmoid() * c_prev) + (in_gate.sigmoid() * candidate.tanh())
        h = (out_gate.sigmoid() * c.tanh()).dropout(self.dropout)
        return Tensor.cat(h, c).realize()
class LSTM:
    """Stack of ``LSTMCell`` layers unrolled over the first axis of the input."""

    def __init__(self, input_size, hidden_size, layers, dropout):
        """Build ``layers`` cells; only the first one consumes ``input_size`` features."""
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.layers = layers
        self.cells = []
        for idx in range(layers):
            if idx == 0:
                cell = LSTMCell(input_size, hidden_size, dropout)
            else:
                # The last layer of a multi-layer stack gets no dropout.
                cell = LSTMCell(hidden_size, hidden_size, 0 if idx == layers - 1 else dropout)
            self.cells.append(cell)

    def __call__(self, x, hc):
        """Run the stack over every step of ``x``.

        Args:
            x: input indexed as ``x[t]`` per step; ``x.shape[1]`` is the batch size.
            hc: initial state of shape ``(layers, 2 * batch, hidden_size)``,
                or ``None`` to start from zeros.

        Returns:
            ``(output, hc)``: the last layer's hidden rows for every step,
            concatenated along axis 0, and the final state. ``output`` is
            ``None`` when ``x`` has no steps.
        """
        @TinyJit
        def _do_step(x_, hc_):
            return self.do_step(x_, hc_)

        if hc is None:
            hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)

        output = None
        for t in range(x.shape[0]):
            # TODO: why is the `+ 1 - 1` needed? (kept as-is from the original)
            hc = _do_step(x[t] + 1 - 1, hc)
            # Last layer, hidden rows only.
            top_hidden = hc[-1:, :x.shape[1]]
            output = top_hidden if output is None else output.cat(top_hidden, dim=0).realize()

        return output, hc

    def do_step(self, x, hc):
        """Feed ``x`` through every cell once and return the stacked, realized layer states."""
        batch = x.shape[0]
        layer_input = x
        states = []
        for idx, cell in enumerate(self.cells):
            # For layers after the first, the slice keeps just the hidden rows of the previous state.
            state = cell(layer_input[:batch], hc[idx])
            states.append(state)
            layer_input = state
        return Tensor.stack(*states).realize()
class StackTime:
    """Reduce the number of time steps by stacking ``factor`` consecutive frames along the feature axis."""

    def __init__(self, factor):
        """Store the number of consecutive frames merged into one output frame."""
        self.factor = factor

    def __call__(self, x, x_lens):
        """Stack frames of a time-major ``(time, batch, features)`` tensor.

        Args:
            x: 3-D tensor; axis 0 is zero-padded up to a multiple of ``factor``.
            x_lens: optional lengths, divided by ``factor`` (not rounded).

        Returns:
            ``(stacked, lens)`` where ``stacked`` has shape
            ``(time // factor, batch, features * factor)`` and ``lens`` is
            ``None`` when ``x_lens`` is ``None``.
        """
        x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
        steps, batch, feats = x.shape
        # Reshape in batch-major layout so only frames of the SAME batch entry are
        # merged. Reshaping the time-major tensor directly interleaves different
        # batch entries whenever batch > 1 (it is equivalent only for batch == 1).
        x = x.transpose(0, 1).reshape(batch, steps // self.factor, feats * self.factor).transpose(0, 1)
        return x, x_lens / self.factor if x_lens is not None else None
class Encoder:
    """Encoder: LSTM stack, frame stacking, then a second LSTM stack."""

    def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
        """Create the pre-stacking LSTM, the frame stacker and the post-stacking LSTM."""
        self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
        self.stack_time = StackTime(stack_time_factor)
        # Stacking multiplies the feature width by the stacking factor.
        stacked_width = stack_time_factor * hidden_size
        self.post_rnn = LSTM(stacked_width, hidden_size, post_layers, dropout)

    def __call__(self, x, x_lens):
        """Encode ``x`` and return ``(features, lens)``.

        Both LSTM stacks start from a zero state. The result has its first two
        axes swapped relative to the LSTM output, and ``lens`` is whatever the
        frame stacker returns for ``x_lens``.
        """
        pre_out, _ = self.pre_rnn(x, None)
        stacked, stacked_lens = self.stack_time(pre_out, x_lens)
        post_out, _ = self.post_rnn(stacked, None)
        return post_out.transpose(0, 1), stacked_lens
class Prediction:
    """Prediction network: label embedding followed by an LSTM stack."""

    def __init__(self, vocab_size, hidden_size, layers, dropout):
        """Create an embedding with ``vocab_size - 1`` rows and the LSTM stack."""
        self.hidden_size = hidden_size
        self.emb = Embedding(vocab_size - 1, hidden_size)
        self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)

    def __call__(self, x, hc, m):
        """Embed ``x``, scale by mask ``m`` and run the LSTM.

        Returns:
            ``(output, hc)``: the LSTM output with its first two axes swapped
            back, and the new LSTM state.
        """
        masked = self.emb(x) * m
        # The LSTM iterates over axis 0, so swap the first two axes around the call.
        seq_out, new_hc = self.rnn(masked.transpose(0, 1), hc)
        return seq_out.transpose(0, 1), new_hc
class Joint:
    """Joint network: combines encoder and prediction features into vocabulary logits."""

    def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
        """Create the two linear layers applied to the concatenated features."""
        self.dropout = dropout
        concat_width = pred_hidden_size + enc_hidden_size
        self.l1 = Linear(concat_width, joint_hidden_size)
        self.l2 = Linear(joint_hidden_size, vocab_size)

    def __call__(self, f, g):
        """Score every pairing of an ``f`` position with a ``g`` position.

        Args:
            f: 3-D tensor ``(_, T, H)``.
            g: 3-D tensor ``(B, U, H2)``.

        Returns:
            Tensor of shape ``(B, T, U, vocab_size)``.
        """
        _, T, H = f.shape
        B, U, H2 = g.shape

        # Broadcast both inputs onto a common (B, T, U, .) grid before concatenating.
        f_grid = f.unsqueeze(2).expand(B, T, U, H)
        g_grid = g.unsqueeze(1).expand(B, T, U, H2)
        joined = f_grid.cat(g_grid, dim=3)

        hidden = self.l1(joined).relu().dropout(self.dropout)
        return self.l2(hidden)
|