@@ -70,25 +70,28 @@ class Node:
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
+      status_type = status_data.get("type", "")
+      if status_type == "supported_inference_engines":
         node_id = status_data.get("node_id")
         engines = status_data.get("engines", [])
         self.topology_inference_engines_pool.append(engines)
-      if status_data.get("type", "") == "node_status":
+      elif status_type == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
+
       download_progress = None
-      if status_data.get("type", "") == "download_progress":
+      if status_type == "download_progress":
         if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
         self.node_download_progress[status_data.get('node_id')] = download_progress
+
       if self.topology_viz:
         self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
     except Exception as e:
-      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: print(f"Error on_node_status: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
   def get_supported_inference_engines(self):
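For context, a minimal sketch of the three opaque-status payloads this handler now dispatches on via status_type. The keys mirror the hunk above; the concrete values (node ids, engine names) are invented for illustration:

import json

# Hypothetical example payloads; keys come from the handler above, values are made up.
supported_engines_msg = json.dumps({
  "type": "supported_inference_engines",
  "node_id": "node-a",               # sender's node id
  "engines": ["mlx", "tinygrad"],    # appended to topology_inference_engines_pool
})
node_status_msg = json.dumps({
  "type": "node_status",
  "node_id": "node-a",
  "status": "start_process_prompt",  # "start_*" sets active_node_id; "end_*" clears it
})
download_progress_msg = json.dumps({
  "type": "download_progress",
  "node_id": "node-a",
  "progress": {},                    # parsed with RepoProgressEvent.from_dict(...)
})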
@@ -153,10 +156,39 @@ class Node:
     request_id: Optional[str] = None,
   ) -> None:
     shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+        }),
+      )
+    )
     start_time = time.perf_counter_ns()
     await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+        }),
+      )
+    )
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
 
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
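A minimal, self-contained sketch of the fire-and-forget pattern this hunk uses: wrapping broadcast_opaque_status in asyncio.create_task schedules the start/end broadcasts in the background instead of delaying _process_prompt. The broadcast function below is a stand-in stub, not the Node method:

import asyncio, json, time

async def broadcast_opaque_status(request_id, status):
  # Stub standing in for Node.broadcast_opaque_status.
  print(f"[{request_id}] {status}")

async def process_prompt(request_id, prompt):
  # Fire-and-forget: schedule the broadcast, do not await it.
  asyncio.create_task(broadcast_opaque_status(
    request_id,
    json.dumps({"type": "node_status", "node_id": "node-a", "status": "start_process_prompt", "request_id": request_id}),
  ))
  start_time = time.perf_counter_ns()
  await asyncio.sleep(0.01)  # stands in for _process_prompt
  elapsed_time_ns = time.perf_counter_ns() - start_time
  asyncio.create_task(broadcast_opaque_status(
    request_id,
    json.dumps({"type": "node_status", "node_id": "node-a", "status": "end_process_prompt", "request_id": request_id, "elapsed_time_ns": elapsed_time_ns}),
  ))
  await asyncio.sleep(0)  # yield once so the pending broadcast task can run

asyncio.run(process_prompt("req-1", "hello"))

The trade-off of not awaiting is that a broadcast can be dropped if the node exits before the task runs; since these are advisory status messages consumed by on_node_status, that appears acceptable here.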