
adding logic to check which models are downloaded

cadenmackenzie 9 months ago
parent commit c7dd3126b1
4 files changed with 155 additions and 103 deletions
  1. exo/api/chatgpt_api.py (+96 -18)
  2. exo/tinychat/index.css (+10 -0)
  3. exo/tinychat/index.html (+4 -4)
  4. exo/tinychat/index.js (+45 -81)

+ 96 - 18
exo/api/chatgpt_api.py

@@ -17,6 +17,7 @@ from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
+import os
 
 
 class Message:
@@ -200,25 +201,102 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
+  def is_model_downloaded(self, model_name):
+    if DEBUG >= 2:
+        print(f"\nChecking if model {model_name} is downloaded:")
+    
+    cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
+    repo = get_repo(model_name, self.inference_engine_classname)
+    
+    if DEBUG >= 2:
+        print(f"  Cache dir: {cache_dir}")
+        print(f"  Repo: {repo}")
+        print(f"  Engine: {self.inference_engine_classname}")
+    
+    if not repo:
+        return False
+
+    # Convert repo path (e.g. "mlx-community/Llama-3.2-1B-Instruct-4bit")
+    # to directory format (e.g. "models--mlx-community--Llama-3.2-1B-Instruct-4bit")
+    repo_parts = repo.split('/')
+    formatted_path = f"models--{repo_parts[0]}--{repo_parts[1]}"
+    repo_path = cache_dir / formatted_path / "snapshots"
+    
+    if DEBUG >= 2:
+        print(f"  Looking in: {repo_path}")
+        
+    if repo_path.exists():
+        # Look for the most recent snapshot directory
+        snapshots = list(repo_path.glob("*"))
+        if snapshots:
+            latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
+            
+            # Check for model files and their index files
+            model_files = (
+                list(latest_snapshot.glob("model.safetensors")) +
+                list(latest_snapshot.glob("model.safetensors.index.json")) +
+                list(latest_snapshot.glob("*.mlx"))
+            )
+            
+            if DEBUG >= 2:
+                print(f"  Latest snapshot: {latest_snapshot}")
+                print(f"  Found files: {model_files}")
+                
+            # The model counts as downloaded if any of these is present:
+            # 1. model.safetensors (single-file weights)
+            # 2. model.safetensors.index.json (index for sharded weights)
+            # 3. a *.mlx weights file
+            return len(model_files) > 0
+    
+    if DEBUG >= 2:
+        print("  No valid model files found")
+    return False
+
   async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name) 
-        for model_name in [
-          model_id for model_id, model_info in model_cards.items() 
-          if all(map(
-            lambda engine: engine in model_info["repo"],
-            list(dict.fromkeys([
-              inference_engine_classes.get(engine_name, None) 
-              for engine_list in self.node.topology_inference_engines_pool 
-              for engine_name in engine_list 
-              if engine_name is not None
-            ] + [self.inference_engine_classname]))
-          ))
-        ]
-      }
-    })
-  
+    try:
+        print("\n=== Model Support Handler Started ===")
+        model_pool = {}
+        
+        print("\nAvailable Models:")
+        print("-" * 50)
+        for model_name, pretty in pretty_name.items():
+            print(f"\nChecking model: {model_name}")
+            if model_name in model_cards:
+                model_info = model_cards[model_name]
+                print(f"Model info: {model_info}")
+                
+                # Get required engines
+                required_engines = list(dict.fromkeys([
+                    inference_engine_classes.get(engine_name, None) 
+                    for engine_list in self.node.topology_inference_engines_pool 
+                    for engine_name in engine_list 
+                    if engine_name is not None
+                ] + [self.inference_engine_classname]))
+                print(f"Required engines: {required_engines}")
+                
+                # Check if model supports required engines
+                if all(map(lambda engine: engine in model_info["repo"], required_engines)):
+                    is_downloaded = self.is_model_downloaded(model_name)
+                    print(f"Model {model_name} download status: {is_downloaded}")
+                    
+                    model_pool[model_name] = {
+                        "name": pretty,
+                        "downloaded": is_downloaded
+                    }
+        
+        print("\nFinal model pool:")
+        print(json.dumps(model_pool, indent=2))
+        print("\n=== Model Support Handler Completed ===\n")
+        
+        return web.json_response({"model pool": model_pool})
+    except Exception as e:
+        print(f"\nError in handle_model_support: {str(e)}")
+        traceback.print_exc()
+        return web.json_response(
+            {"detail": f"Server error: {str(e)}"}, 
+            status=500
+        )
+
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
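
The download check above relies on the Hugging Face hub cache layout: a repo id such as mlx-community/Llama-3.2-1B-Instruct-4bit is cached under ~/.cache/huggingface/hub/models--mlx-community--Llama-3.2-1B-Instruct-4bit/snapshots/<revision>/. Below is a minimal standalone sketch of the same lookup; the helper name hf_snapshot_has_weights and the hard-coded default cache path are illustrative assumptions (the real cache location can be redirected with environment variables such as HF_HOME).

from pathlib import Path
from typing import Optional

def hf_snapshot_has_weights(repo: str, cache_dir: Optional[Path] = None) -> bool:
    # Map "org/name" to the hub cache directory name "models--org--name".
    cache_dir = cache_dir or Path.home() / ".cache" / "huggingface" / "hub"
    org, name = repo.split("/", 1)
    snapshots = cache_dir / f"models--{org}--{name}" / "snapshots"
    if not snapshots.exists():
        return False
    # Pick the most recently modified snapshot, mirroring is_model_downloaded.
    latest = max(snapshots.iterdir(), key=lambda p: p.stat().st_mtime, default=None)
    if latest is None:
        return False
    # Treat single-file safetensors, a sharded-weights index, or MLX weights as "downloaded".
    return (
        any(latest.glob("model.safetensors"))
        or any(latest.glob("model.safetensors.index.json"))
        or any(latest.glob("*.mlx"))
    )

print(hf_snapshot_has_weights("mlx-community/Llama-3.2-1B-Instruct-4bit"))

As an aside, huggingface_hub also exposes scan_cache_dir(), which enumerates cached repos and their revisions and could replace the manual path construction here.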
 

+ 10 - 0
exo/tinychat/index.css

@@ -490,4 +490,14 @@ p {
 
 .clear-history-button i {
   font-size: 14px;
+}
+
+.model-select option.model-not-downloaded {
+  color: #888;
+  font-style: italic;
+}
+
+.model-select option.model-downloaded {
+  color: inherit;
+  font-weight: 500;
 }

+ 4 - 4
exo/tinychat/index.html

@@ -26,13 +26,13 @@
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
      <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity class="toast">
+    <div x-show="errorMessage !== null" x-transition.opacity class="toast">
         <div class="toast-header">
-            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
             <div class="toast-header-buttons">
                 <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
                         class="toast-expand-button" 
-                        x-show="errorMessage.stack">
+                        x-show="errorMessage?.stack">
                     <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
                 </button>
                 <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
@@ -41,7 +41,7 @@
             </div>
         </div>
         <div class="toast-content" x-show="errorExpanded" x-transition>
-            <span x-text="errorMessage.stack"></span>
+            <span x-text="errorMessage?.stack || ''"></span>
         </div>
     </div>
 <div class="model-selector">

+ 45 - 81
exo/tinychat/index.js

@@ -10,13 +10,15 @@ document.addEventListener("alpine:init", () => {
     // historical state
     histories: JSON.parse(localStorage.getItem("histories")) || [],
 
-    home: 0,
-    generating: false,
-    endpoint: `${window.location.origin}/v1`,
+    // Initialize error message structure
     errorMessage: null,
     errorExpanded: false,
     errorTimeout: null,
 
+    home: 0,
+    generating: false,
+    endpoint: `${window.location.origin}/v1`,
+
     // performance tracking
     time_till_first: 0,
     tokens_per_second: 0,
@@ -76,51 +78,37 @@ document.addEventListener("alpine:init", () => {
 
     async populateSelector() {
       try {
+        console.log("Fetching model pool...");
         const response = await fetch(`${window.location.origin}/modelpool`);
-        const responseText = await response.text(); // Get raw response text first
-        
         if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        
-        // Try to parse the response text
-        let responseJson;
-        try {
-          responseJson = JSON.parse(responseText);
-        } catch (parseError) {
-          console.error('Failed to parse JSON:', parseError);
-          throw new Error(`Invalid JSON response: ${responseText}`);
+          const errorText = await response.text();
+          throw new Error(`HTTP error! status: ${response.status}\n${errorText}`);
         }
 
-        const sel = document.querySelector(".model-select");
-        if (!sel) {
-          throw new Error("Could not find model selector element");
-        }
+        const data = await response.json();
+        console.log("Received model pool data:", data);
 
-        // Clear the current options and add new ones
+        const sel = document.querySelector('.model-select');
         sel.innerHTML = '';
-          
-        const modelDict = responseJson["model pool"];
-        if (!modelDict) {
-          throw new Error("Response missing 'model pool' property");
-        }
 
-        Object.entries(modelDict).forEach(([key, value]) => {
+        // Convert the model pool to an array of [key, value] pairs and sort by name
+        const sortedModels = Object.entries(data["model pool"]).sort((a, b) => 
+          a[1].name.localeCompare(b[1].name)
+        );
+
+        console.log("Sorted models:", sortedModels);
+
+        sortedModels.forEach(([key, value]) => {
           const opt = document.createElement("option");
           opt.value = key;
-          opt.textContent = value;
+          opt.textContent = `${value.name}${value.downloaded ? ' (downloaded)' : ''}`;
+          opt.classList.add(value.downloaded ? 'model-downloaded' : 'model-not-downloaded');
           sel.appendChild(opt);
+          console.log(`Added model: ${key} (${value.name}) - Downloaded: ${value.downloaded}`);
         });
-
-        // Set initial value to the first model
-        const firstKey = Object.keys(modelDict)[0];
-        if (firstKey) {
-          sel.value = firstKey;
-          this.cstate.selectedModel = firstKey;
-        }
       } catch (error) {
         console.error("Error populating model selector:", error);
-        this.errorMessage = `Failed to load models: ${error.message}`;
+        this.setError(error);
       }
     },
 
@@ -169,29 +157,7 @@ document.addEventListener("alpine:init", () => {
         this.processMessage(value);
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
         this.generating = false;
       }
     },
@@ -309,29 +275,7 @@ document.addEventListener("alpine:init", () => {
         }
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
       } finally {
         this.generating = false;
       }
@@ -467,6 +411,26 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
     },
+
+    // Helper to set the error toast state consistently
+    setError(error) {
+      this.errorMessage = {
+        basic: error.message || "An unknown error occurred",
+        stack: error.stack || ""
+      };
+      this.errorExpanded = false;
+      
+      if (this.errorTimeout) {
+        clearTimeout(this.errorTimeout);
+      }
+
+      if (!this.errorExpanded) {
+        this.errorTimeout = setTimeout(() => {
+          this.errorMessage = null;
+          this.errorExpanded = false;
+        }, 30 * 1000);
+      }
+    },
   }));
 });
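
For a quick end-to-end check of the new payload, the endpoint can be queried directly. The sketch below assumes the exo ChatGPT-compatible API is reachable at localhost:52415 (exo's usual default; adjust host/port to your setup) and that the response keeps the {"model pool": {model_id: {"name": ..., "downloaded": ...}}} shape introduced in this commit.

import json
import urllib.request

# Assumed base URL; point this at wherever the exo API is actually listening.
BASE_URL = "http://localhost:52415"

with urllib.request.urlopen(f"{BASE_URL}/modelpool") as resp:
    pool = json.load(resp)["model pool"]

# Mirror the frontend: sort by pretty name and flag download status.
for model_id, info in sorted(pool.items(), key=lambda kv: kv[1]["name"]):
    marker = "downloaded" if info["downloaded"] else "not downloaded"
    print(f"{info['name']:<45} {marker}  ({model_id})")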