|
@@ -225,9 +225,9 @@ class Transformer:
|
|
|
h = inputs
|
|
|
return h
|
|
|
|
|
|
- def forward(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
|
|
|
- if x.shape[0:2] == (1, 1) and self.forward_jit is not None:
|
|
|
- return self.forward_jit(x, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
|
|
|
+ def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
|
|
|
+ if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
|
|
|
+ return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
|
|
|
return self.forward_base(x, start_pos, cache=cache)
|
|
|
|
|
|
def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
|