Преглед на файлове

clean up DEBUG=2 logs, a few fixes for token

Alex Cheema преди 6 месеца
родител
ревизия
55d1846f5e
променени са 4 файла, в които са добавени 14 реда и са изтрити 14 реда
  1. 1 1
      exo/download/hf/hf_helpers.py
  2. 2 1
      exo/download/hf/hf_shard_download.py
  3. 5 5
      exo/main.py
  4. 6 7
      exo/orchestration/node.py

+ 1 - 1
exo/download/hf/hf_helpers.py

@@ -441,7 +441,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
     shard_specific_patterns = set(["*.safetensors"])
-  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
+  if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 
 async def get_file_download_percentage(

+ 2 - 1
exo/download/hf/hf_shard_download.py

@@ -159,13 +159,14 @@ class HFShardDownloader(ShardDownloader):
           print(f"Download calculation for {self.current_repo_id}:")
           print(f"Total bytes: {total_bytes}")
           print(f"Downloaded bytes: {downloaded_bytes}")
+        if DEBUG >= 3:
           for file in relevant_files:
             print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
 
       return status
 
     except Exception as e:
-      if DEBUG >= 2:
+      if DEBUG >= 3:
         print(f"Error getting shard download status: {e}")
         traceback.print_exc()
       return None

+ 5 - 5
exo/main.py

@@ -187,10 +187,10 @@ api = ChatGPTAPI(
   system_prompt=args.system_prompt
 )
 buffered_token_output = {}
-def update_topology_viz(req_id, token, __):
+def update_topology_viz(req_id, tokens, __):
   if not topology_viz: return
-  if req_id in buffered_token_output: buffered_token_output[req_id].append(token)
-  else: buffered_token_output[req_id] = [token]
+  if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
+  else: buffered_token_output[req_id] = tokens
 
   if inference_engine.shard.model_id != 'stable-diffusion-2-1-base':
     topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
@@ -243,8 +243,8 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     await node.process_prompt(shard, prompt, request_id=request_id)
 
     tokens = []
-    def on_token(_request_id, _token, _is_finished):
-      tokens.append(_token)
+    def on_token(_request_id, _tokens, _is_finished):
+      tokens.extend(_tokens)
       return _request_id == request_id and _is_finished
     await callback.wait(on_token, timeout=300)
 

+ 6 - 7
exo/orchestration/node.py

@@ -47,7 +47,7 @@ class Node:
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self.default_sample_temperature = default_sample_temperature
-    self._on_token = AsyncCallbackSystem[str, Tuple[str, int, bool]]()
+    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
@@ -130,9 +130,8 @@ class Node:
         self.buffered_token_output[request_id][0].append(token.item())
         is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-        asyncio.create_task(self.broadcast_result(request_id, [self.buffered_token_output[request_id][0][-1]], is_finished))
         forward = token.reshape(1, -1)
-        intermediate_result = self.buffered_token_output[request_id][0][-1]
+        intermediate_result = [self.buffered_token_output[request_id][0][-1]]
       else:
         forward = result
     else:
@@ -575,16 +574,16 @@ class Node:
     return self.topology
 
   @property
-  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, int, bool]]:
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
     return self._on_token
 
   @property
   def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
     return self._on_opaque_status
 
-  def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
-    self.on_token.trigger_all(request_id, token, is_finished)
+  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
+    self.on_token.trigger_all(request_id, tokens, is_finished)
   
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):