@@ -182,6 +182,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
+    cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -232,11 +233,11 @@ class ChatGPTAPI:
         }
       )
       await response.prepare(request)
-
+
       for model_name, pretty in pretty_name.items():
         if model_name in model_cards:
           model_info = model_cards[model_name]
-
+
           if self.inference_engine_classname in model_info.get("repo", {}):
             shard = build_base_shard(model_name, self.inference_engine_classname)
             if shard:
@@ -244,11 +245,11 @@ class ChatGPTAPI:
               downloader.current_shard = shard
               downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
               status = await downloader.get_shard_download_status()
-
+
               download_percentage = status.get("overall") if status else None
               total_size = status.get("total_size") if status else None
               total_downloaded = status.get("total_downloaded") if status else False
-
+
               model_data = {
                 model_name: {
                   "name": pretty,
@@ -258,17 +259,17 @@ class ChatGPTAPI:
                   "total_downloaded": total_downloaded
                 }
               }
-
+
               await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
+
       await response.write(b"data: [DONE]\n\n")
       return response
-
+
     except Exception as e:
       print(f"Error in handle_model_support: {str(e)}")
       traceback.print_exc()
       return web.json_response(
-        {"detail": f"Server error: {str(e)}"},
+        {"detail": f"Server error: {str(e)}"},
         status=500
       )
 
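For reference, the `handle_model_support` handler touched above streams per-model download status as server-sent events: one `data:` line of JSON per model, terminated by `data: [DONE]`. A minimal client sketch for consuming that stream, assuming the handler is mounted at GET /modelpool (the route registration is outside this diff) and the server listens on its default port:

import asyncio
import json

import aiohttp

async def watch_model_support(base_url: str = "http://localhost:52415") -> None:
  # Consumes the SSE stream from handle_model_support; /modelpool is an assumed path.
  async with aiohttp.ClientSession() as session:
    async with session.get(f"{base_url}/modelpool") as resp:
      async for raw in resp.content:
        line = raw.decode().strip()
        if not line.startswith("data: "):
          continue  # skip the blank separator lines between events
        payload = line[len("data: "):]
        if payload == "[DONE]":
          break
        for name, info in json.loads(payload).items():
          print(name, info.get("download_percentage"))

if __name__ == "__main__":
  asyncio.run(watch_model_support())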
@@ -425,35 +426,35 @@ class ChatGPTAPI:
     try:
       model_name = request.match_info.get('model_name')
       if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
-
+
       if not model_name or model_name not in model_cards:
         return web.json_response(
-          {"detail": f"Invalid model name: {model_name}"},
+          {"detail": f"Invalid model name: {model_name}"},
           status=400
         )
 
       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard:
         return web.json_response(
-          {"detail": "Could not build shard for model"},
+          {"detail": "Could not build shard for model"},
           status=400
         )
 
       repo_id = get_repo(shard.model_id, self.inference_engine_classname)
       if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
-
+
       # Get the HF cache directory using the helper function
       hf_home = get_hf_home()
       cache_dir = get_repo_root(repo_id)
-
+
       if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
-
+
       if os.path.exists(cache_dir):
         if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
         try:
           shutil.rmtree(cache_dir)
           return web.json_response({
-            "status": "success",
+            "status": "success",
             "message": f"Model {model_name} deleted successfully",
             "path": str(cache_dir)
           })
@@ -465,7 +466,7 @@ class ChatGPTAPI:
         return web.json_response({
           "detail": f"Model files not found at {cache_dir}"
         }, status=404)
-
+
     except Exception as e:
       print(f"Error in handle_delete_model: {str(e)}")
       traceback.print_exc()
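The two hunks above are whitespace-only cleanups inside `handle_delete_model`, which resolves the model's Hugging Face cache directory via `get_repo_root` and removes it with `shutil.rmtree`. A hedged sketch of calling it, assuming the handler is registered as DELETE /models/{model_name}; that registration, and the model name used here, are not part of this diff:

import asyncio

import aiohttp

async def delete_model(model_name: str, base_url: str = "http://localhost:52415") -> dict:
  # DELETE /models/{model_name} is an assumed route; adjust to the actual registration.
  async with aiohttp.ClientSession() as session:
    async with session.delete(f"{base_url}/models/{model_name}") as resp:
      # {"status": "success", ...} on 200; {"detail": ...} on 400/404/500
      return await resp.json()

if __name__ == "__main__":
  print(asyncio.run(delete_model("example-model")))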
@@ -543,6 +544,20 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
 
+  async def handle_get_topology(self, request):
+    try:
+      topology = self.node.current_topology
+      if topology:
+        return web.json_response(topology.to_json())
+      else:
+        return web.json_response({})
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Error getting topology: {str(e)}"},
+        status=500
+      )
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
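With the new route and handler in place, any client can read the node's current view of the cluster over HTTP. A minimal sketch, assuming the API runs on the default 0.0.0.0:52415 from `run` above; the endpoint returns `Topology.to_json()` once a topology has been collected and {} before then:

import asyncio

import aiohttp

async def get_topology(base_url: str = "http://localhost:52415") -> dict:
  # GET /topology: returns {} until the node has collected a topology.
  async with aiohttp.ClientSession() as session:
    async with session.get(f"{base_url}/topology") as resp:
      resp.raise_for_status()
      return await resp.json()

if __name__ == "__main__":
  print(asyncio.run(get_topology()))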