Varshith 1 éve
szülő
commit
54993995dc

+ 12 - 7
exo/inference/mlx/sharded_model.py

@@ -11,11 +11,12 @@ class StatefulShardedModel:
     def __init__(self, shard: Shard, model: nn.Module):
         self.shard = shard
         self.model = model
-        self.reset()
+        self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
 
     def step(
         self,
-        y,
+        request_id: str,
+        x,
         pixel_values=None,
         temp: float = 0.0,
         top_p: float = 1.0,
@@ -37,11 +38,15 @@ class StatefulShardedModel:
 
             return token
 
-        # TODO : revert hacky fix
+        y = x
+
+        if request_id not in self.request_cache:
+            self.init_cache(request_id)
+
         if pixel_values is None:
-            output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+            output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
         else:
-            output = self.model(y, pixel_values=pixel_values, cache=self.cache)
+            output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
 
         if self.shard.is_last_layer():
             logits = output[:, -1, :]
@@ -59,10 +64,10 @@ class StatefulShardedModel:
     ) -> Generator[Tuple[mx.array, mx.array], None, None]:
         return self.step(x, temp, top_p, logit_bias)
 
-    def reset(self):
+    def init_cache(self, request_id: str):
         kv_heads = (
             [self.model.n_kv_heads] * len(self.model.layers)
             if isinstance(self.model.n_kv_heads, int)
             else self.model.n_kv_heads
         )
-        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+        self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]

+ 6 - 6
exo/inference/mlx/test_sharded_llava.py

@@ -34,11 +34,11 @@ pixel_values = mx.array(inputs["pixel_values"])
 input_ids = mx.array(inputs["input_ids"])
 
 print(prompt)
-y = full.step(input_ids, pixel_values, temp=0)
+y = full.step("full", input_ids, pixel_values, temp=0)
 full_generated_tokens = [y.item()]
 
 for _ in range(13):
-    y = full.step(y, temp=0)
+    y = full.step("full", y, temp=0)
     full_generated_tokens.append(y.item())
 
 full_response = full_processor.tokenizer.decode(full_generated_tokens)
@@ -48,13 +48,13 @@ inputs = processor1(prompt, img, return_tensors="np")
 pixel_values = mx.array(inputs["pixel_values"])
 input_ids = mx.array(inputs["input_ids"])
 
-y = m1.step(input_ids, pixel_values, temp=0)
-y = m2.step(y, temp=0)
+y = m1.step("shard", input_ids, pixel_values, temp=0)
+y = m2.step("shard", y, temp=0)
 full_generated_tokens = [y.item()]
 
 for _ in range(13):
-    y = m1.step(y, temp=0)
-    y = m2.step(y, temp=0)
+    y = m1.step("shard", y, temp=0)
+    y = m2.step("shard", y, temp=0)
     full_generated_tokens.append(y.item())
 
 sharded_response = processor2.tokenizer.decode(full_generated_tokens)