@@ -219,54 +219,62 @@ class ChatGPTAPI:
   async def handle_model_support(self, request):
     try:
-      model_pool = {}
-
-      for model_name, pretty in pretty_name.items():
-        if model_name in model_cards:
-          model_info = model_cards[model_name]
-
-          # Get required engines from the node's topology directly
-          required_engines = list(dict.fromkeys(
-            [engine_name for engine_list in self.node.topology_inference_engines_pool
-             for engine_name in engine_list
-             if engine_name is not None] +
-            [self.inference_engine_classname]
-          ))
-          # Check if model supports required engines
-          if all(map(lambda engine: engine in model_info["repo"], required_engines)):
-            shard = build_base_shard(model_name, self.inference_engine_classname)
-            if shard:
-              # Use HFShardDownloader to check status without initiating download
-              downloader = HFShardDownloader(quick_check=True)  # quick_check=True prevents downloads
-              downloader.current_shard = shard
-              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-              status = await downloader.get_shard_download_status()
-              if DEBUG >= 2:
-                print(f"Download status for {model_name}: {status}")
-
-              # Get overall percentage from status
-              download_percentage = status.get("overall") if status else None
-              total_size = status.get("total_size") if status else None
-              total_downloaded = status.get("total_downloaded") if status else False
-              if DEBUG >= 2 and download_percentage is not None:
-                print(f"Overall download percentage for {model_name}: {download_percentage}")
-
-              model_pool[model_name] = {
-                "name": pretty,
-                "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                "download_percentage": download_percentage,
-                "total_size": total_size,
-                "total_downloaded": total_downloaded
-              }
-
-      return web.json_response({"model pool": model_pool})
+      response = web.StreamResponse(
+        status=200,
+        reason='OK',
+        headers={
+          'Content-Type': 'text/event-stream',
+          'Cache-Control': 'no-cache',
+          'Connection': 'keep-alive',
+        }
+      )
+      await response.prepare(request)
+
+      for model_name, pretty in pretty_name.items():
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          required_engines = list(dict.fromkeys(
+            [engine_name for engine_list in self.node.topology_inference_engines_pool
+             for engine_name in engine_list
+             if engine_name is not None] +
+            [self.inference_engine_classname]
+          ))
+
+          if all(map(lambda engine: engine in model_info["repo"], required_engines)):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
+
+              model_data = {
+                model_name: {
+                  "name": pretty,
+                  "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                  "download_percentage": download_percentage,
+                  "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      await response.write(b"data: [DONE]\n\n")
+      return response
+
     except Exception as e:
-      print(f"Error in handle_model_support: {str(e)}")
-      traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Server error: {str(e)}"},
-        status=500
-      )
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Server error: {str(e)}"},
+        status=500
+      )

   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
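
A note on the required_engines construction, which is unchanged between the two sides of the hunk: list(dict.fromkeys(...)) flattens the per-node engine lists and deduplicates them while preserving first-seen order, which a plain set() would not guarantee. A standalone sketch of the idiom, with made-up pool contents shaped like topology_inference_engines_pool (one list of engine class names, possibly containing None, per node):

# Illustrative values only; the real pool comes from the node's topology.
pool = [["MLXDynamicShardInferenceEngine", None], ["TinygradDynamicShardInferenceEngine"]]
local_engine = "MLXDynamicShardInferenceEngine"

flattened = [engine for engines in pool for engine in engines if engine is not None]
required = list(dict.fromkeys(flattened + [local_engine]))  # ordered, no duplicates
print(required)
# -> ['MLXDynamicShardInferenceEngine', 'TinygradDynamicShardInferenceEngine']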
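
The net effect of this hunk: instead of buffering every model's status into a single "model pool" JSON object, the handler now streams one Server-Sent Events frame per model as each HFShardDownloader status check resolves, then terminates the stream with data: [DONE]. Below is a minimal consumer sketch; the URL is an assumption (neither the route bound to handle_model_support nor the listen port is shown in this hunk), so substitute the actual endpoint:

import asyncio
import json

import aiohttp


async def watch_model_support(url):
  # Consume the SSE stream line by line; each event is a "data: <payload>" frame.
  async with aiohttp.ClientSession() as session:
    async with session.get(url) as resp:
      async for raw_line in resp.content:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
          continue  # blank separator lines between frames
        payload = line[len("data: "):]
        if payload == "[DONE]":
          break  # server finished enumerating models
        event = json.loads(payload)  # {model_name: {"name": ..., "download_percentage": ...}}
        for model_name, info in event.items():
          print(f"{model_name}: {info.get('download_percentage')}% downloaded")


# Hypothetical endpoint; replace with the route actually registered for handle_model_support.
asyncio.run(watch_model_support("http://localhost:52415/modelpool"))

Streaming keeps clients responsive when many repos have to be checked: the first model statuses can be rendered immediately rather than only after the slowest HuggingFace lookup completes.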