@@ -20,34 +20,3 @@ def make_prompt_state(x, model, shard):
   cache = [create_kv_cache(x, model.layers[i].attention.max_context, model.layers[i].attention.n_kv_heads, model.layers[i].attention.head_dim) for i in range(shard.start_layer, shard.end_layer + 1)]
   return ModelState(cache)
-
-
-
-class StatefulModel:
-  def __init__(self, model, max_states: int = 2):
-    super().__init__()
-    self.model = model
-    self.max_states = max_states
-    self.states = OrderedDict()
-
-  def init_cache(self, x: Tensor, request_id: str):
-    if len(self.states) >= self.max_states:
-      self.states.popitem(last=False)
-
-    self.states[request_id] = make_prompt_state(x, self.model, self.model.shard)
-
-  def __call__(self, x: Tensor, request_id: Optional[str] = None, use_cache: bool = True):
-    h = self.model.embed(x)
-    #print(f"StatefulModel in <- {h}")
-    if use_cache and request_id is not None:
-      if request_id not in self.states:
-        self.init_cache(h, request_id)
-      else:
-        self.states.move_to_end(request_id)
-      out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
-      self.states[request_id].start += h.shape[1]
-    else:
-      out = self.model.forward(h, 0)
-    #print(f"StatefulModel out -> {out}")
-    return out
-