@@ -235,7 +235,56 @@ class Transformer:
     h = self.embed(x)
     return self.forward(h, start_pos, cache=cache)
 
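+# Runs a contiguous [start_layer, end_layer] slice of a full Transformer:
+# only the first shard embeds tokens, and only the last shard produces logits.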
+class TransformerShard:
+  def __init__(
+    self,
+    shard: Shard,
+    base,
+    jit: bool = True,
+  ):
+    shardrange = range(shard.start_layer, shard.end_layer + 1)
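+    # keep only the layers that fall inside this shard's inclusive range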
+    self.layers = [layer for layer, n in zip(base.layers, range(shard.n_layers)) if n in shardrange]
+    self.norm = base.norm
+    self.tok_embeddings = base.tok_embeddings
+    self.embed = (lambda x: self.tok_embeddings(x)) if shard.is_first_layer() else (lambda x: x)
+    self.output = base.output
+    # the last shard applies the final norm before the logits projection
+    self.post = (lambda x: self.output(self.norm(x))) if shard.is_last_layer() else (lambda x: x)
+    self.max_context = base.max_context
+    self.null_cache = [None for _ in shardrange]
+    self.freqs_cis = base.freqs_cis
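+    # JIT only helps the repeated single-token decode step; see forward() below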
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
+
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache):
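+    # slice the RoPE table to this window; the causal mask is only needed for prefill (seqlen > 1)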
+    seqlen = x.shape[1]
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
+
+    for layer, c in zip(self.layers, cache):
+      x = layer(x, start_pos, freqs_cis, mask, cache=c)
+
+    out = self.post(x)
+    return out
+
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
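+    # single-token decode after the first call takes the JIT path with start_pos bound as a Variable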
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
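+  # mirrors Transformer.__call__: embed (first shard only), then run this shard's layers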
+  def __call__(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
+    # TODO: better way to handle the first call vs. the rest?
+    h = self.embed(x)
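+    # no cache on the first call: substitute a per-layer list of Nones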
+    return self.forward(h, start_pos, cache=self.null_cache if cache is None else cache)
+
 
 # *** helpers ***
 
 