Procházet zdrojové kódy

buffered_logits and buffered_inputs don't need to be (list, finished) tuples; store plain lists instead

Nel Nibcord před 5 měsíci
rodič
revize
bf33ffde87
1 změnil soubory, kde provedl 5 přidání a 6 odebrání
  1. 5 6
      exo/orchestration/standard_node.py

+ 5 - 6
exo/orchestration/standard_node.py

@@ -39,8 +39,8 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.buffered_inputs: Dict[str, Tuple[List[np.ndarray], bool]] = {}
-    self.buffered_logits: Dict[str, Tuple[List[np.ndarray], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -117,13 +117,13 @@ class StandardNode(Node):
       self.buffered_token_output[request_id] = ([], False)
     
     if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = ([], False)
+      self.buffered_logits[request_id] = []
 
-    self.buffered_logits[request_id][0] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
+    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
 
     if shard.is_last_layer():
       result = await self.inference_engine.sample(result)
-      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0]) + 1})
+      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id]) + 1})
     
     await self.inference_engine.ensure_shard(shard)
     is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
@@ -138,7 +138,6 @@ class StandardNode(Node):
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      self.buffered_logits[request_id] = (self.buffered_logits[request_id][0], True)
     else:
       asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))