|
@@ -204,51 +204,51 @@ class ChatGPTAPI:
|
|
|
|
|
|
async def handle_model_support(self, request):
|
|
|
try:
|
|
|
- model_pool = {}
|
|
|
-
|
|
|
- for model_name, pretty in pretty_name.items():
|
|
|
- if model_name in model_cards:
|
|
|
- model_info = model_cards[model_name]
|
|
|
-
|
|
|
- # Get required engines
|
|
|
- required_engines = list(dict.fromkeys([
|
|
|
- inference_engine_classes.get(engine_name, None)
|
|
|
- for engine_list in self.node.topology_inference_engines_pool
|
|
|
- for engine_name in engine_list
|
|
|
- if engine_name is not None
|
|
|
- ] + [self.inference_engine_classname]))
|
|
|
-
|
|
|
- # Check if model supports required engines
|
|
|
- if all(map(lambda engine: engine in model_info["repo"], required_engines)):
|
|
|
- shard = build_base_shard(model_name, self.inference_engine_classname)
|
|
|
- if shard:
|
|
|
- # Use HFShardDownloader to check status without initiating download
|
|
|
- downloader = HFShardDownloader(quick_check=True) # quick_check=True prevents downloads
|
|
|
- downloader.current_shard = shard
|
|
|
- downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
|
|
- status = await downloader.get_shard_download_status()
|
|
|
- if DEBUG >= 2:
|
|
|
- print(f"Download status for {model_name}: {status}")
|
|
|
-
|
|
|
- # Get overall percentage from status
|
|
|
- download_percentage = status.get("overall") if status else None
|
|
|
- if DEBUG >= 2 and download_percentage is not None:
|
|
|
- print(f"Overall download percentage for {model_name}: {download_percentage}")
|
|
|
-
|
|
|
- model_pool[model_name] = {
|
|
|
- "name": pretty,
|
|
|
- "downloaded": download_percentage == 100 if download_percentage is not None else False,
|
|
|
- "download_percentage": download_percentage
|
|
|
- }
|
|
|
-
|
|
|
- return web.json_response({"model pool": model_pool})
|
|
|
+ model_pool = {}
|
|
|
+
|
|
|
+ for model_name, pretty in pretty_name.items():
|
|
|
+ if model_name in model_cards:
|
|
|
+ model_info = model_cards[model_name]
|
|
|
+
|
|
|
+ # Get required engines
|
|
|
+ required_engines = list(dict.fromkeys([
|
|
|
+ inference_engine_classes.get(engine_name, None)
|
|
|
+ for engine_list in self.node.topology_inference_engines_pool
|
|
|
+ for engine_name in engine_list
|
|
|
+ if engine_name is not None
|
|
|
+ ] + [self.inference_engine_classname]))
|
|
|
+
|
|
|
+ # Check if model supports required engines
|
|
|
+ if all(map(lambda engine: engine in model_info["repo"], required_engines)):
|
|
|
+ shard = build_base_shard(model_name, self.inference_engine_classname)
|
|
|
+ if shard:
|
|
|
+ # Use HFShardDownloader to check status without initiating download
|
|
|
+ downloader = HFShardDownloader(quick_check=True) # quick_check=True prevents downloads
|
|
|
+ downloader.current_shard = shard
|
|
|
+ downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
|
|
+ status = await downloader.get_shard_download_status()
|
|
|
+ if DEBUG >= 2:
|
|
|
+ print(f"Download status for {model_name}: {status}")
|
|
|
+
|
|
|
+ # Get overall percentage from status
|
|
|
+ download_percentage = status.get("overall") if status else None
|
|
|
+ if DEBUG >= 2 and download_percentage is not None:
|
|
|
+ print(f"Overall download percentage for {model_name}: {download_percentage}")
|
|
|
+
|
|
|
+ model_pool[model_name] = {
|
|
|
+ "name": pretty,
|
|
|
+ "downloaded": download_percentage == 100 if download_percentage is not None else False,
|
|
|
+ "download_percentage": download_percentage
|
|
|
+ }
|
|
|
+
|
|
|
+ return web.json_response({"model pool": model_pool})
|
|
|
except Exception as e:
|
|
|
- print(f"Error in handle_model_support: {str(e)}")
|
|
|
- traceback.print_exc()
|
|
|
- return web.json_response(
|
|
|
- {"detail": f"Server error: {str(e)}"},
|
|
|
- status=500
|
|
|
- )
|
|
|
+ print(f"Error in handle_model_support: {str(e)}")
|
|
|
+ traceback.print_exc()
|
|
|
+ return web.json_response(
|
|
|
+ {"detail": f"Server error: {str(e)}"},
|
|
|
+ status=500
|
|
|
+ )
|
|
|
|
|
|
async def handle_get_models(self, request):
|
|
|
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
|