Pārlūkot izejas kodu

Don't broadcast a node status for every single process_tensor call

Alex Cheema 5 mēneši atpakaļ
vecāks
revīzija
b17faa8199
1 mainītis faili ar 2 papildinājumiem un 31 dzēšanām
  1. 2 31
      exo/orchestration/node.py

+ 2 - 31
exo/orchestration/node.py

@@ -156,6 +156,7 @@ class Node:
     request_id: Optional[str] = None,
   ) -> None:
     shard = self.get_current_shard(base_shard)
+    start_time = time.perf_counter_ns()
     asyncio.create_task(
       self.broadcast_opaque_status(
         request_id,
@@ -170,7 +171,6 @@ class Node:
         }),
       )
     )
-    start_time = time.perf_counter_ns()
     await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
@@ -351,39 +351,11 @@ class Node:
     request_id: Optional[str] = None,
   ) -> None:
     shard = self.get_current_shard(base_shard)
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "start_process_tensor",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "tensor_size": tensor.size,
-          "tensor_shape": tensor.shape,
-          "request_id": request_id,
-        }),
-      )
-    )
     start_time = time.perf_counter_ns()
     await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "end_process_tensor",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "request_id": request_id,
-          "elapsed_time_ns": elapsed_time_ns,
-        }),
-      )
-    )
+    if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
 
   async def _process_tensor(
     self,
@@ -395,7 +367,6 @@ class Node:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       self.outstanding_requests[request_id] = "processing"
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)