
add a system prompt that makes it aware of exo cluster topology

Alex Cheema 9 months ago
parent
commit 89965e9568
3 changed files with 28 additions and 0 deletions
  1. exo/api/chatgpt_api.py  +16 -0
  2. exo/topology/device_capabilities.py  +3 -0
  3. exo/topology/topology.py  +9 -0

+ 16 - 0
exo/api/chatgpt_api.py

@@ -208,6 +208,22 @@ class ChatGPTAPI:
         tokenizer = await resolve_tokenizer(shard.model_id)
         if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
+        # Add a system prompt giving the model context about the exo cluster topology
+        topology = self.node.current_topology
+        node_list = '\n'.join(f'{d.model} {d.memory_gb()}GB, {d.flops.fp16}TFLOPS (fp16)' for d in topology.nodes.values())
+        system_message = {
+            "role": "system",
+            "content": f"""
+            You are an AI assistant running on a distributed system called exo. The current topology of the system is:
+            {len(topology.nodes)} nodes:
+            {node_list}
+            Total memory: {topology.total_memory_gb()}GB.
+            Total TFLOPS: {topology.total_tflops_fp16()}TFLOPS (fp16).
+            Please consider this information when processing requests.
+            Keep responses to one sentence, concise and friendly for conversational voice output.
+            """
+        }
+        chat_request.messages.insert(0, system_message)
         prompt = build_prompt(tokenizer, chat_request.messages)
         callback_id = f"chatgpt-api-wait-response-{request_id}"
         callback = self.node.on_token.register(callback_id)
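
For illustration, on a hypothetical two-node cluster (device names and figures below are invented), the injected system message's content would render roughly as:

    You are an AI assistant running on a distributed system called exo. The current topology of the system is:
    2 nodes:
    MacBook Pro 16.0GB, 10.6TFLOPS (fp16)
    Mac Studio 64.0GB, 26.98TFLOPS (fp16)
    Total memory: 80.0GB.
    Total TFLOPS: 37.58TFLOPS (fp16).
    Please consider this information when processing requests.
    Keep responses to one sentence, concise and friendly for conversational voice output.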

+ 3 - 0
exo/topology/device_capabilities.py

@@ -32,6 +32,9 @@ class DeviceCapabilities:
         if isinstance(self.flops, dict):
             self.flops = DeviceFlops(**self.flops)
 
+    def memory_gb(self) -> float:
+        return round(self.memory / 1024, 2)
+
     def to_dict(self):
         return {
             'model': self.model,
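
This helper presumably assumes the memory field is tracked in MB, so a device reporting memory=16384 would yield memory_gb() == 16.0.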

+ 9 - 0
exo/topology/topology.py

@@ -35,6 +35,15 @@ class Topology:
                     edges.append((node, neighbor))
         return edges
 
+    def total_memory(self) -> int:
+        return sum([node.memory for node in self.nodes.values()])
+
+    def total_memory_gb(self) -> float:
+        return round(self.total_memory() / 1024, 2)
+
+    def total_tflops_fp16(self) -> float:
+        return round(sum([node.flops.fp16 for node in self.nodes.values()]), 2)
+
     def merge(self, other: 'Topology'):
         for node_id, capabilities in other.nodes.items():
             self.update_node(node_id, capabilities)
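
A minimal usage sketch of the new aggregate helpers, assuming the same MB convention for memory and that Topology, DeviceCapabilities, and DeviceFlops can be constructed with the fields shown (their exact signatures are not part of this diff):

    from exo.topology.topology import Topology
    from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops

    topology = Topology()
    # Hypothetical devices; the fp32/int8 figures are placeholders.
    topology.update_node("node-a", DeviceCapabilities(model="MacBook Pro", memory=16384, flops=DeviceFlops(fp32=5.3, fp16=10.6, int8=21.2)))
    topology.update_node("node-b", DeviceCapabilities(model="Mac Studio", memory=65536, flops=DeviceFlops(fp32=13.49, fp16=26.98, int8=53.96)))

    print(topology.total_memory())       # 81920 (MB)
    print(topology.total_memory_gb())    # 80.0
    print(topology.total_tflops_fp16())  # 37.58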