Browse Source

Merge pull request #2 from cadenmackenzie/downloadedModelsV2Revisions

working versions
Caden MacKenzie 9 months ago
parent
commit
bd2985aebd

+ 24 - 59
exo/api/chatgpt_api.py

@@ -19,6 +19,7 @@ from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 import os
 from exo.download.hf.hf_helpers import get_hf_home
+from exo.download.hf.hf_shard_download import HFShardDownloader
 
 
 class Message:
@@ -202,57 +203,6 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
-  def is_model_downloaded(self, model_name):
-    if DEBUG >= 2:
-        print(f"\nChecking if model {model_name} is downloaded:")
-    
-    cache_dir = get_hf_home() / "hub"
-    repo = get_repo(model_name, self.inference_engine_classname)
-    
-    if DEBUG >= 2:
-        print(f"  Cache dir: {cache_dir}")
-        print(f"  Repo: {repo}")
-        print(f"  Engine: {self.inference_engine_classname}")
-    
-    if not repo:
-        return False
-
-    # Convert repo path (e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit")
-    # to directory format (e.g. "models--mlx-community--Llama-3.2-1B-Instruct-4bit")
-    repo_parts = repo.split('/')
-    formatted_path = f"models--{repo_parts[0]}--{repo_parts[1]}"
-    repo_path = cache_dir / formatted_path / "snapshots"
-    
-    if DEBUG >= 2:
-        print(f"  Looking in: {repo_path}")
-        
-    if repo_path.exists():
-        # Look for the most recent snapshot directory
-        snapshots = list(repo_path.glob("*"))
-        if snapshots:
-            latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
-            
-            # Check for model files and their index files
-            model_files = (
-                list(latest_snapshot.glob("model.safetensors")) +
-                list(latest_snapshot.glob("model.safetensors.index.json")) +
-                list(latest_snapshot.glob("*.mlx"))
-            )
-            
-            if DEBUG >= 2:
-                print(f"  Latest snapshot: {latest_snapshot}")
-                print(f"  Found files: {model_files}")
-                
-            # Model is considered downloaded if we find either:
-            # 1. model.safetensors file
-            # 2. model.safetensors.index.json file (for sharded models)
-            # 3. *.mlx file
-            return len(model_files) > 0
-    
-    if DEBUG >= 2:
-        print("  No valid model files found")
-    return False
-
   async def handle_model_support(self, request):
     try:
         model_pool = {}
@@ -271,14 +221,29 @@ class ChatGPTAPI:
                 
                 # Check if model supports required engines
                 if all(map(lambda engine: engine in model_info["repo"], required_engines)):
-                    is_downloaded = self.is_model_downloaded(model_name)
-                    if DEBUG >= 2:
-                        print(f"Model {model_name} download status: {is_downloaded}")
-                    
-                    model_pool[model_name] = {
-                        "name": pretty,
-                        "downloaded": is_downloaded
-                    }
+                    shard = build_base_shard(model_name, self.inference_engine_classname)
+                    if shard:
+                        downloader = HFShardDownloader()
+                        downloader.current_shard = shard
+                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+                        status = await downloader.get_shard_download_status()
+                        if DEBUG >= 2:
+                            print(f"Download status for {model_name}: {status}")
+                        
+                        # Calculate overall percentage if we have status
+                        download_percentage = None
+                        if status:
+                            percentages = list(status.values())
+                            if percentages:
+                                download_percentage = sum(percentages) / len(percentages)
+                                if DEBUG >= 2:
+                                    print(f"Calculated download percentage for {model_name}: {download_percentage}")
+                        
+                        model_pool[model_name] = {
+                            "name": pretty,
+                            "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                            "download_percentage": download_percentage
+                        }
         
         return web.json_response({"model pool": model_pool})
     except Exception as e:
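
The aggregation above collapses the per-file dict returned by `get_shard_download_status()` into a single number with an unweighted mean, so each weight file counts equally regardless of its size. A minimal sketch of that reduction, under the assumption that the helper name `overall_percentage` is hypothetical and not part of the commit:

```python
from typing import Dict, Optional

def overall_percentage(status: Optional[Dict[str, float]]) -> Optional[float]:
  # Hypothetical helper mirroring the reduction in handle_model_support:
  # status maps file names to their download percentage (0-100).
  if not status:
    return None
  percentages = list(status.values())
  return sum(percentages) / len(percentages)

# Example: one finished and one half-finished shard file average to 75.0,
# so "downloaded" stays False until every file reaches 100.
assert overall_percentage({"model-00001.safetensors": 100.0,
                           "model-00002.safetensors": 50.0}) == 75.0
```

Note the design choice this implies: a sharded model is only reported as downloaded when every matched file is complete, and partially fetched files pull the displayed percentage down proportionally.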

+ 64 - 2
exo/download/hf/hf_shard_download.py

@@ -1,13 +1,18 @@
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
+from exo.download.hf.hf_helpers import (
+    download_repo_files, RepoProgressEvent, get_weight_map, 
+    get_allow_patterns, get_repo_root, fetch_file_list, get_local_snapshot_dir
+)
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
+import aiohttp
+from aiofiles import os as aios
 
 
 class HFShardDownloader(ShardDownloader):
@@ -17,8 +22,13 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+    self.current_shard: Optional[Shard] = None
+    self.current_repo_id: Optional[str] = None
+    self.revision: str = "main"
 
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    self.current_shard = shard
+    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
     repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
@@ -77,3 +87,55 @@ class HFShardDownloader(ShardDownloader):
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
+
+  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+    if not self.current_shard or not self.current_repo_id:
+        if DEBUG >= 2: print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
+        return None
+            
+    try:
+        snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
+        if not snapshot_dir:
+            if DEBUG >= 2: print(f"No snapshot directory found for {self.current_repo_id}")
+            return None
+
+        # Get the weight map to know what files we need
+        weight_map = await get_weight_map(self.current_repo_id, self.revision)
+        if not weight_map:
+            if DEBUG >= 2: print(f"No weight map found for {self.current_repo_id}")
+            return None
+        
+        # Get the patterns for this shard
+        patterns = get_allow_patterns(weight_map, self.current_shard)
+        
+        # First check which files exist locally
+        status = {}
+        local_files = []
+        local_sizes = {}
+        
+        for pattern in patterns:
+            if pattern.endswith('safetensors') or pattern.endswith('mlx'):
+                file_path = snapshot_dir / pattern
+                if await aios.path.exists(file_path):
+                    local_size = await aios.path.getsize(file_path)
+                    local_files.append(pattern)
+                    local_sizes[pattern] = local_size
+
+        # Only fetch remote info if we found local files
+        if local_files:
+            async with aiohttp.ClientSession() as session:
+                file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+                
+                for pattern in local_files:
+                    for file in file_list:
+                        if file["path"].endswith(pattern):
+                            status[pattern] = (local_sizes[pattern] / file["size"]) * 100
+                            break
+
+        return status
+      
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Error getting shard download status: {e}")
+            traceback.print_exc()
+        return None
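
Because `ensure_shard` now records `current_shard` and `current_repo_id`, the status probe can also be driven by hand, the way `handle_model_support` does above. A minimal sketch, where the model key `"llama-3.2-1b"` and the engine class name are illustrative assumptions rather than values taken from the commit:

```python
import asyncio
from exo.models import build_base_shard, get_repo
from exo.download.hf.hf_shard_download import HFShardDownloader

async def main():
  # Assumed (model, engine) pair; any entry in exo.models.model_cards would do.
  engine = "MLXDynamicShardInferenceEngine"
  shard = build_base_shard("llama-3.2-1b", engine)
  if shard is None:
    return
  downloader = HFShardDownloader()
  downloader.current_shard = shard
  downloader.current_repo_id = get_repo(shard.model_id, engine)
  status = await downloader.get_shard_download_status()
  # e.g. {"model.safetensors": 42.0}, {} if no files are local yet,
  # or None when there is no snapshot dir or weight map.
  print(status)

asyncio.run(main())
```

Note that `handle_model_support` builds a fresh `HFShardDownloader` per model purely to reuse this probe; the instance never starts a download.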

+ 11 - 1
exo/download/shard_download.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -26,6 +26,16 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
 
+  @abstractmethod
+  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+    """Get the download status of shards.
+    
+    Returns:
+        Optional[Dict[str, float]]: A dictionary mapping file names to their download percentage (0-100),
+        or None if the status cannot be determined
+    """
+    pass
+
 
 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
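
The `+ 11 - 1` stat above implies `NoopShardDownloader` changed as well; the hunk is cut off before its body, but since `get_shard_download_status` is now abstract, every concrete downloader must implement it. Presumably a trivial stub along these lines (a sketch, not the verbatim commit):

```python
from typing import Dict, Optional

class NoopShardDownloader(ShardDownloader):
  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
    # The noop downloader never fetches anything, so there is no status.
    return None
```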

+ 35 - 11
exo/tinychat/index.js

@@ -34,6 +34,8 @@ document.addEventListener("alpine:init", () => {
     // Pending message storage
     pendingMessage: null,
 
+    modelPoolInterval: null,
+
     init() {
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
@@ -43,6 +45,9 @@ document.addEventListener("alpine:init", () => {
       
       // Call populateSelector immediately after initialization
       this.populateSelector();
+      this.modelPoolInterval = setInterval(() => {
+        this.populateSelector();
+      }, 5000);
     },
 
     removeHistory(cstate) {
@@ -83,21 +88,40 @@ document.addEventListener("alpine:init", () => {
       try {
         const response = await fetch(`${window.location.origin}/modelpool`);
         if (!response.ok) {
-          const errorText = await response.text();
-          throw new Error(`HTTP error! status: ${response.status}\n${errorText}`);
+          throw new Error(`HTTP error! status: ${response.status}`);
         }
 
         const data = await response.json();
+        console.log("Model pool data:", data);
+        
         const sel = document.querySelector('.model-select');
-        sel.innerHTML = '';
-
-        // Use the model pool entries in their original order
-        Object.entries(data["model pool"]).forEach(([key, value]) => {
-          const opt = document.createElement("option");
-          opt.value = key;
-          opt.textContent = `${value.name}${value.downloaded ? ' (downloaded)' : ''}`;
-          sel.appendChild(opt);
-        });
+        
+        // Only create options if they don't exist
+        if (sel.children.length === 0) {
+          Object.entries(data["model pool"]).forEach(([key, value]) => {
+            const opt = document.createElement("option");
+            opt.value = key;
+            opt.dataset.modelName = value.name;  // Store base name in dataset
+            opt.textContent = value.name;
+            sel.appendChild(opt);
+          });
+        }
+        
+        // Update existing options text
+        Array.from(sel.options).forEach(opt => {
+          const modelInfo = data["model pool"][opt.value];
+          if (modelInfo) {
+            let displayText = modelInfo.name;
+            if (modelInfo.download_percentage != null) {
+              if (modelInfo.downloaded) {
+                displayText += ' (downloaded)';
+              } else {
+                displayText += ` (${Math.round(modelInfo.download_percentage)}% downloaded)`;
+              }
+            }
+            opt.textContent = displayText;
+          }
+        });
       } catch (error) {
         console.error("Error populating model selector:", error);
         this.setError(error);
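
The 5-second poller above re-renders the option labels from the `/modelpool` payload produced by `handle_model_support`. To inspect that payload outside the browser, a quick probe will do; the host and port here are assumptions, so match them to wherever your exo API is listening:

```python
import json
import urllib.request

# Assumes a local exo API on port 8000; adjust host/port to your deployment.
with urllib.request.urlopen("http://localhost:8000/modelpool") as resp:
  pool = json.load(resp)["model pool"]

for key, info in pool.items():
  pct = info["download_percentage"]
  label = ("downloaded" if info["downloaded"]
           else f"{pct:.0f}% downloaded" if pct is not None
           else "not downloaded")
  print(f"{info['name']} ({key}): {label}")
```

This mirrors the label logic in `populateSelector`: a model shows `(downloaded)` only when every file is complete, a partial fetch shows its rounded percentage, and a model with no local files shows no suffix at all.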