Browse Source

Merge branch 'main' into githubactions

Alex Cheema 4 tháng trước cách đây
mục cha
commit
4d028608a7
2 tập tin đã thay đổi với 51 bổ sung16 xóa
  1. 31 16
      exo/api/chatgpt_api.py
  2. 20 0
      exo/topology/topology.py

+ 31 - 16
exo/api/chatgpt_api.py

@@ -182,6 +182,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
+    cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -232,11 +233,11 @@ class ChatGPTAPI:
             }
         )
         await response.prepare(request)
-        
+
         for model_name, pretty in pretty_name.items():
             if model_name in model_cards:
                 model_info = model_cards[model_name]
-                
+
                 if self.inference_engine_classname in model_info.get("repo", {}):
                     shard = build_base_shard(model_name, self.inference_engine_classname)
                     if shard:
@@ -244,11 +245,11 @@ class ChatGPTAPI:
                         downloader.current_shard = shard
                         downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
                         status = await downloader.get_shard_download_status()
-                        
+
                         download_percentage = status.get("overall") if status else None
                         total_size = status.get("total_size") if status else None
                         total_downloaded = status.get("total_downloaded") if status else False
-                        
+
                         model_data = {
                             model_name: {
                                 "name": pretty,
@@ -258,17 +259,17 @@ class ChatGPTAPI:
                                 "total_downloaded": total_downloaded
                             }
                         }
-                        
+
                         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-        
+
         await response.write(b"data: [DONE]\n\n")
         return response
-        
+
     except Exception as e:
         print(f"Error in handle_model_support: {str(e)}")
         traceback.print_exc()
         return web.json_response(
-            {"detail": f"Server error: {str(e)}"}, 
+            {"detail": f"Server error: {str(e)}"},
             status=500
         )
 
@@ -425,35 +426,35 @@ class ChatGPTAPI:
     try:
       model_name = request.match_info.get('model_name')
       if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
-      
+
       if not model_name or model_name not in model_cards:
         return web.json_response(
-          {"detail": f"Invalid model name: {model_name}"}, 
+          {"detail": f"Invalid model name: {model_name}"},
           status=400
           )
 
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard:
         return web.json_response(
-          {"detail": "Could not build shard for model"}, 
+          {"detail": "Could not build shard for model"},
           status=400
         )
 
       repo_id = get_repo(shard.model_id, self.inference_engine_classname)
       if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
-      
+
       # Get the HF cache directory using the helper function
       hf_home = get_hf_home()
       cache_dir = get_repo_root(repo_id)
-      
+
       if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
-      
+
       if os.path.exists(cache_dir):
         if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
         try:
           shutil.rmtree(cache_dir)
           return web.json_response({
-            "status": "success", 
+            "status": "success",
             "message": f"Model {model_name} deleted successfully",
             "path": str(cache_dir)
           })
@@ -465,7 +466,7 @@ class ChatGPTAPI:
         return web.json_response({
           "detail": f"Model files not found at {cache_dir}"
         }, status=404)
-            
+
     except Exception as e:
         print(f"Error in handle_delete_model: {str(e)}")
         traceback.print_exc()
@@ -543,6 +544,20 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
 
+  async def handle_get_topology(self, request):
+    try:
+      topology = self.node.current_topology
+      if topology:
+        return web.json_response(topology.to_json())
+      else:
+        return web.json_response({})
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Error getting topology: {str(e)}"},
+        status=500
+      )
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()

+ 20 - 0
exo/topology/topology.py

@@ -51,3 +51,23 @@ class Topology:
     edges_str = ", ".join(f"{node}: {[f'{c.to_id}({c.description})' for c in conns]}"
                          for node, conns in self.peer_graph.items())
     return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
+
+  def to_json(self):
+    return {
+      "nodes": {
+        node_id: capabilities.to_dict()
+        for node_id, capabilities in self.nodes.items()
+      },
+      "peer_graph": {
+        node_id: [
+          {
+            "from_id": conn.from_id,
+            "to_id": conn.to_id,
+            "description": conn.description
+          }
+          for conn in connections
+        ]
+        for node_id, connections in self.peer_graph.items()
+      },
+      "active_node_id": self.active_node_id
+    }