|
@@ -11,11 +11,12 @@ class StatefulShardedModel:
|
|
|
def __init__(self, shard: Shard, model: nn.Module):
|
|
|
self.shard = shard
|
|
|
self.model = model
|
|
|
- self.reset()
|
|
|
+ self.request_cache: Dict[str, list[KVCache]] = {}  # request_id -> list of KVCache objects, filled in by init_cache()
|
|
|
|
|
|
def step(
|
|
|
self,
|
|
|
- y,
|
|
|
+ request_id: str,
|
|
|
+ x,
|
|
|
pixel_values=None,
|
|
|
temp: float = 0.0,
|
|
|
top_p: float = 1.0,
|
|
@@ -37,11 +38,15 @@ class StatefulShardedModel:
|
|
|
|
|
|
return token
|
|
|
|
|
|
- # TODO : revert hacky fix
|
|
|
+ y = x
|
|
|
+
|
|
|
+ if request_id not in self.request_cache:
|
|
|
+ self.init_cache(request_id)
|
|
|
+
|
|
|
if pixel_values is None:
|
|
|
- output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
|
|
|
+ output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
|
|
|
else:
|
|
|
- output = self.model(y, pixel_values=pixel_values, cache=self.cache)
|
|
|
+ output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
|
|
|
|
|
|
if self.shard.is_last_layer():
|
|
|
logits = output[:, -1, :]
|
|
@@ -59,10 +64,10 @@ class StatefulShardedModel:
|
|
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
|
|
return self.step(x, temp, top_p, logit_bias)
|
|
|
|
|
|
- def reset(self):
|
|
|
+ def init_cache(self, request_id: str) -> None:  # build a fresh list of KVCache objects and store it under request_id
|
|
|
kv_heads = (
|
|
|
[self.model.n_kv_heads] * len(self.model.layers)
|
|
|
if isinstance(self.model.n_kv_heads, int)
|
|
|
else self.model.n_kv_heads
|
|
|
)
|
|
|
- self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
|
|
|
+ self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]  # one KVCache per kv_heads entry; replaces any existing entry for this request_id
|