|
@@ -15,7 +15,7 @@ from exo.helpers import AsyncCallbackSystem
|
|
|
from exo.viz.topology_viz import TopologyViz
|
|
|
|
|
|
class StandardNode(Node):
|
|
|
- def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None):
|
|
|
+ def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None, disable_tui: Optional[bool] = False):
|
|
|
self.id = id
|
|
|
self.inference_engine = inference_engine
|
|
|
self.server = server
|
|
@@ -25,7 +25,7 @@ class StandardNode(Node):
|
|
|
self.topology: Topology = Topology()
|
|
|
self.device_capabilities = device_capabilities()
|
|
|
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
|
|
|
- self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
|
|
|
+ self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
|
|
|
self.max_generate_tokens = max_generate_tokens
|
|
|
self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
|
|
|
self._on_opaque_status = AsyncCallbackSystem[str, str]()
|
|
@@ -40,7 +40,8 @@ class StandardNode(Node):
|
|
|
elif status_data.get("status", "").startswith("end_"):
|
|
|
if status_data.get("node_id") == self.current_topology.active_node_id:
|
|
|
self.current_topology.active_node_id = None
|
|
|
- self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
|
|
|
+ if self.topology_viz:
|
|
|
+ self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
|
|
|
except json.JSONDecodeError:
|
|
|
pass
|
|
|
|
|
@@ -242,7 +243,8 @@ class StandardNode(Node):
|
|
|
|
|
|
next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
|
|
|
self.topology = next_topology
|
|
|
- self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
|
|
|
+ if self.topology_viz:
|
|
|
+ self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
|
|
|
return next_topology
|
|
|
|
|
|
# TODO: unify this and collect_topology as global actions
|