Procházet zdrojové kódy

Remove the is_model_downloaded method; the "downloaded" flag is now derived from download_percentage (true only when it equals 100, false when it is None)

cadenmackenzie před 9 měsíci
rodič
revize
dfcf513d55
Změněn 1 soubor: 1 přidání a 53 odebrání
  1. 1 53
      exo/api/chatgpt_api.py

+ 1 - 53
exo/api/chatgpt_api.py

@@ -203,57 +203,6 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
-  def is_model_downloaded(self, model_name):
-    if DEBUG >= 2:
-        print(f"\nChecking if model {model_name} is downloaded:")
-    
-    cache_dir = get_hf_home() / "hub"
-    repo = get_repo(model_name, self.inference_engine_classname)
-    
-    if DEBUG >= 2:
-        print(f"  Cache dir: {cache_dir}")
-        print(f"  Repo: {repo}")
-        print(f"  Engine: {self.inference_engine_classname}")
-    
-    if not repo:
-        return False
-
-    # Convert repo path (e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit")
-    # to directory format (e.g. "models--mlx-community--Llama-3.2-1B-Instruct-4bit")
-    repo_parts = repo.split('/')
-    formatted_path = f"models--{repo_parts[0]}--{repo_parts[1]}"
-    repo_path = cache_dir / formatted_path / "snapshots"
-    
-    if DEBUG >= 2:
-        print(f"  Looking in: {repo_path}")
-        
-    if repo_path.exists():
-        # Look for the most recent snapshot directory
-        snapshots = list(repo_path.glob("*"))
-        if snapshots:
-            latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
-            
-            # Check for model files and their index files
-            model_files = (
-                list(latest_snapshot.glob("model.safetensors")) +
-                list(latest_snapshot.glob("model.safetensors.index.json")) +
-                list(latest_snapshot.glob("*.mlx"))
-            )
-            
-            if DEBUG >= 2:
-                print(f"  Latest snapshot: {latest_snapshot}")
-                print(f"  Found files: {model_files}")
-                
-            # Model is considered downloaded if we find either:
-            # 1. model.safetensors file
-            # 2. model.safetensors.index.json file (for sharded models)
-            # 3. *.mlx file
-            return len(model_files) > 0
-    
-    if DEBUG >= 2:
-        print("  No valid model files found")
-    return False
-
   async def handle_model_support(self, request):
     try:
         model_pool = {}
@@ -272,7 +221,6 @@ class ChatGPTAPI:
                 
                 # Check if model supports required engines
                 if all(map(lambda engine: engine in model_info["repo"], required_engines)):
-                    # Create a shard for status checking
                     shard = build_base_shard(model_name, self.inference_engine_classname)
                     if shard:
                         downloader = HFShardDownloader()
@@ -293,7 +241,7 @@ class ChatGPTAPI:
                         
                         model_pool[model_name] = {
                             "name": pretty,
-                            "downloaded": self.is_model_downloaded(model_name),
+                            "downloaded": download_percentage == 100 if download_percentage is not None else False,
                             "download_percentage": download_percentage
                         }