cadenmackenzie 8 місяців тому
батько
коміт
8135437c4a
2 змінених файлів з 85 додано та 82 видалено
  1. 44 44
      exo/api/chatgpt_api.py
  2. 41 38
      exo/download/hf/hf_helpers.py

+ 44 - 44
exo/api/chatgpt_api.py

@@ -204,51 +204,51 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-        model_pool = {}
-        
-        for model_name, pretty in pretty_name.items():
-            if model_name in model_cards:
-                model_info = model_cards[model_name]
-                
-                # Get required engines
-                required_engines = list(dict.fromkeys([
-                    inference_engine_classes.get(engine_name, None) 
-                    for engine_list in self.node.topology_inference_engines_pool 
-                    for engine_name in engine_list 
-                    if engine_name is not None
-                ] + [self.inference_engine_classname]))
-                
-                # Check if model supports required engines
-                if all(map(lambda engine: engine in model_info["repo"], required_engines)):
-                    shard = build_base_shard(model_name, self.inference_engine_classname)
-                    if shard:
-                        # Use HFShardDownloader to check status without initiating download
-                        downloader = HFShardDownloader(quick_check=True)  # quick_check=True prevents downloads
-                        downloader.current_shard = shard
-                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-                        status = await downloader.get_shard_download_status()
-                        if DEBUG >= 2:
-                            print(f"Download status for {model_name}: {status}")
-                        
-                        # Get overall percentage from status
-                        download_percentage = status.get("overall") if status else None
-                        if DEBUG >= 2 and download_percentage is not None:
-                            print(f"Overall download percentage for {model_name}: {download_percentage}")
-                        
-                        model_pool[model_name] = {
-                            "name": pretty,
-                            "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                            "download_percentage": download_percentage
-                        }
-        
-        return web.json_response({"model pool": model_pool})
+      model_pool = {}
+      
+      for model_name, pretty in pretty_name.items():
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+          
+          # Get required engines
+          required_engines = list(dict.fromkeys([
+              inference_engine_classes.get(engine_name, None) 
+              for engine_list in self.node.topology_inference_engines_pool 
+              for engine_name in engine_list 
+              if engine_name is not None
+          ] + [self.inference_engine_classname]))
+          
+          # Check if model supports required engines
+          if all(map(lambda engine: engine in model_info["repo"], required_engines)):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+                # Use HFShardDownloader to check status without initiating download
+              downloader = HFShardDownloader(quick_check=True)  # quick_check=True prevents downloads
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+              if DEBUG >= 2:
+                  print(f"Download status for {model_name}: {status}")
+              
+              # Get overall percentage from status
+              download_percentage = status.get("overall") if status else None
+              if DEBUG >= 2 and download_percentage is not None:
+                  print(f"Overall download percentage for {model_name}: {download_percentage}")
+              
+              model_pool[model_name] = {
+                  "name": pretty,
+                  "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                  "download_percentage": download_percentage
+              }
+      
+      return web.json_response({"model pool": model_pool})
     except Exception as e:
-        print(f"Error in handle_model_support: {str(e)}")
-        traceback.print_exc()
-        return web.json_response(
-            {"detail": f"Server error: {str(e)}"}, 
-            status=500
-        )
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Server error: {str(e)}"}, 
+        status=500
+      )
 
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])

+ 41 - 38
exo/download/hf/hf_helpers.py

@@ -427,45 +427,48 @@ async def get_file_download_percentage(
     repo_id: str,
     revision: str,
     file_path: str,
-    snapshot_dir: Path
+    snapshot_dir: Path,
 ) -> float:
-    """
+  """
     Calculate the download percentage for a file by comparing local and remote sizes.
     """
-    try:
-        local_path = snapshot_dir / file_path
-        if not await aios.path.exists(local_path):
-            return 0
-
-        # Get local file size first
-        local_size = await aios.path.getsize(local_path)
-        if local_size == 0:
-            return 0
-            
-        # Check remote size
-        base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-        url = urljoin(base_url, file_path)
-        headers = await get_auth_headers()
-        
-        # Use HEAD request with redirect following for all files
-        async with session.head(url, headers=headers, allow_redirects=True) as response:
-            if response.status != 200:
-                if DEBUG >= 2: print(f"Failed to get remote file info for {file_path}: {response.status}")
-                return 0
-                
-            remote_size = int(response.headers.get('Content-Length', 0))
-            
-            if remote_size == 0:
-                if DEBUG >= 2: print(f"Remote size is 0 for {file_path}")
-                return 0
-                
-            # Only return 100% if sizes match exactly
-            if local_size == remote_size:
-                return 100.0
-                
-            # Calculate percentage based on sizes
-            return (local_size / remote_size) * 100 if remote_size > 0 else 0
-        
-    except Exception as e:
-        if DEBUG >= 2: print(f"Error checking file download status for {file_path}: {e}")
+  try:
+    local_path = snapshot_dir / file_path
+    if not await aios.path.exists(local_path):
+      return 0
+
+    # Get local file size first
+    local_size = await aios.path.getsize(local_path)
+    if local_size == 0:
+      return 0
+
+    # Check remote size
+    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+    url = urljoin(base_url, file_path)
+    headers = await get_auth_headers()
+
+    # Use HEAD request with redirect following for all files
+    async with session.head(url, headers=headers, allow_redirects=True) as response:
+      if response.status != 200:
+        if DEBUG >= 2:
+          print(f"Failed to get remote file info for {file_path}: {response.status}")
         return 0
+
+      remote_size = int(response.headers.get('Content-Length', 0))
+
+      if remote_size == 0:
+        if DEBUG >= 2:
+          print(f"Remote size is 0 for {file_path}")
+        return 0
+
+      # Only return 100% if sizes match exactly
+      if local_size == remote_size:
+        return 100.0
+
+      # Calculate percentage based on sizes
+      return (local_size / remote_size) * 100 if remote_size > 0 else 0
+
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Error checking file download status for {file_path}: {e}")
+    return 0