Browse Source

Merge pull request #456 from cadenmackenzie/downloadedModelsV2

Show downloaded models, improve error handling, ability to delete models, side bar with more detail, button to go back to chat history
Alex Cheema 6 months ago
parent
commit
3adaba6ab8

+ 123 - 9
exo/api/chatgpt_api.py

@@ -2,6 +2,7 @@ import uuid
 import time
 import time
 import asyncio
 import asyncio
 import json
 import json
+import os
 from pathlib import Path
 from pathlib import Path
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
 from typing import List, Literal, Union, Dict
@@ -14,10 +15,12 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown
 from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
-from exo.apputil import create_animation_mp4
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
 from typing import Callable, Optional
-import tempfile
+from exo.download.hf.hf_shard_download import HFShardDownloader
+import shutil
+from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
+from exo.apputil import create_animation_mp4
 
 
 class Message:
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -175,6 +178,8 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
+    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
 
 
@@ -216,12 +221,56 @@ class ChatGPTAPI:
     return web.json_response({"status": "ok"})
     return web.json_response({"status": "ok"})
 
 
   async def handle_model_support(self, request):
   async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
-      }
-    })
+    try:
+        response = web.StreamResponse(
+            status=200,
+            reason='OK',
+            headers={
+                'Content-Type': 'text/event-stream',
+                'Cache-Control': 'no-cache',
+                'Connection': 'keep-alive',
+            }
+        )
+        await response.prepare(request)
+        
+        for model_name, pretty in pretty_name.items():
+            if model_name in model_cards:
+                model_info = model_cards[model_name]
+                
+                if self.inference_engine_classname in model_info.get("repo", {}):
+                    shard = build_base_shard(model_name, self.inference_engine_classname)
+                    if shard:
+                        downloader = HFShardDownloader(quick_check=True)
+                        downloader.current_shard = shard
+                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+                        status = await downloader.get_shard_download_status()
+                        
+                        download_percentage = status.get("overall") if status else None
+                        total_size = status.get("total_size") if status else None
+                        total_downloaded = status.get("total_downloaded") if status else False
+                        
+                        model_data = {
+                            model_name: {
+                                "name": pretty,
+                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                                "download_percentage": download_percentage,
+                                "total_size": total_size,
+                                "total_downloaded": total_downloaded
+                            }
+                        }
+                        
+                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+        
+        await response.write(b"data: [DONE]\n\n")
+        return response
+        
+    except Exception as e:
+        print(f"Error in handle_model_support: {str(e)}")
+        traceback.print_exc()
+        return web.json_response(
+            {"detail": f"Server error: {str(e)}"}, 
+            status=500
+        )
 
 
   async def handle_get_models(self, request):
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
@@ -372,6 +421,71 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
 
+  async def handle_delete_model(self, request):
+    try:
+      model_name = request.match_info.get('model_name')
+      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
+      
+      if not model_name or model_name not in model_cards:
+        return web.json_response(
+          {"detail": f"Invalid model name: {model_name}"}, 
+          status=400
+          )
+
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard:
+        return web.json_response(
+          {"detail": "Could not build shard for model"}, 
+          status=400
+        )
+
+      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
+      
+      # Get the HF cache directory using the helper function
+      hf_home = get_hf_home()
+      cache_dir = get_repo_root(repo_id)
+      
+      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
+      
+      if os.path.exists(cache_dir):
+        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
+        try:
+          shutil.rmtree(cache_dir)
+          return web.json_response({
+            "status": "success", 
+            "message": f"Model {model_name} deleted successfully",
+            "path": str(cache_dir)
+          })
+        except Exception as e:
+          return web.json_response({
+            "detail": f"Failed to delete model files: {str(e)}"
+          }, status=500)
+      else:
+        return web.json_response({
+          "detail": f"Model files not found at {cache_dir}"
+        }, status=404)
+            
+    except Exception as e:
+        print(f"Error in handle_delete_model: {str(e)}")
+        traceback.print_exc()
+        return web.json_response({
+            "detail": f"Server error: {str(e)}"
+        }, status=500)
+
+  async def handle_get_initial_models(self, request):
+    model_data = {}
+    for model_name, pretty in pretty_name.items():
+        model_data[model_name] = {
+            "name": pretty,
+            "downloaded": None,  # Initially unknown
+            "download_percentage": None,  # Change from 0 to null
+            "total_size": None,
+            "total_downloaded": None,
+            "loading": True  # Add loading state
+        }
+    return web.json_response(model_data)
+
   async def handle_create_animation(self, request):
   async def handle_create_animation(self, request):
     try:
     try:
       data = await request.json()
       data = await request.json()

+ 62 - 2
exo/download/hf/hf_helpers.py

@@ -166,10 +166,18 @@ async def download_file(
     downloaded_size = local_file_size
     downloaded_size = local_file_size
     downloaded_this_session = 0
     downloaded_this_session = 0
     mode = 'ab' if use_range_request else 'wb'
     mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
+    percentage = await get_file_download_percentage(
+      session,
+      repo_id,
+      revision,
+      file_path,
+      Path(save_directory)
+    )
+    
+    if percentage == 100:
       if DEBUG >= 2: print(f"File already downloaded: {file_path}")
       if DEBUG >= 2: print(f"File already downloaded: {file_path}")
       if progress_callback:
       if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
       return
       return
 
 
     if response.status == 200:
     if response.status == 200:
@@ -432,6 +440,57 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
   return list(default_patterns | shard_specific_patterns)
 
 
+async def get_file_download_percentage(
+    session: aiohttp.ClientSession,
+    repo_id: str,
+    revision: str,
+    file_path: str,
+    snapshot_dir: Path,
+) -> float:
+  """
+    Calculate the download percentage for a file by comparing local and remote sizes.
+    """
+  try:
+    local_path = snapshot_dir / file_path
+    if not await aios.path.exists(local_path):
+      return 0
+
+    # Get local file size first
+    local_size = await aios.path.getsize(local_path)
+    if local_size == 0:
+      return 0
+
+    # Check remote size
+    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
+    url = urljoin(base_url, file_path)
+    headers = await get_auth_headers()
+
+    # Use HEAD request with redirect following for all files
+    async with session.head(url, headers=headers, allow_redirects=True) as response:
+      if response.status != 200:
+        if DEBUG >= 2:
+          print(f"Failed to get remote file info for {file_path}: {response.status}")
+        return 0
+
+      remote_size = int(response.headers.get('Content-Length', 0))
+
+      if remote_size == 0:
+        if DEBUG >= 2:
+          print(f"Remote size is 0 for {file_path}")
+        return 0
+
+      # Only return 100% if sizes match exactly
+      if local_size == remote_size:
+        return 100.0
+
+      # Calculate percentage based on sizes
+      return (local_size / remote_size) * 100 if remote_size > 0 else 0
+
+  except Exception as e:
+    if DEBUG >= 2:
+      print(f"Error checking file download status for {file_path}: {e}")
+    return 0
+
 async def has_hf_home_read_access() -> bool:
 async def has_hf_home_read_access() -> bool:
   hf_home = get_hf_home()
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.R_OK)
   try: return await aios.access(hf_home, os.R_OK)
@@ -441,3 +500,4 @@ async def has_hf_home_write_access() -> bool:
   hf_home = get_hf_home()
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.W_OK)
   try: return await aios.access(hf_home, os.W_OK)
   except OSError: return False
   except OSError: return False
+

+ 90 - 2
exo/download/hf/hf_shard_download.py

@@ -1,13 +1,20 @@
 import asyncio
 import asyncio
 import traceback
 import traceback
 from pathlib import Path
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional, Union
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
+from exo.download.hf.hf_helpers import (
+    download_repo_files, RepoProgressEvent, get_weight_map, 
+    get_allow_patterns, get_repo_root, fetch_file_list, 
+    get_local_snapshot_dir, get_file_download_percentage,
+    filter_repo_objects
+)
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
 from exo.models import model_cards, get_repo
+import aiohttp
+from aiofiles import os as aios
 
 
 
 
 class HFShardDownloader(ShardDownloader):
 class HFShardDownloader(ShardDownloader):
@@ -17,8 +24,13 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+    self.current_shard: Optional[Shard] = None
+    self.current_repo_id: Optional[str] = None
+    self.revision: str = "main"
 
 
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    self.current_shard = shard
+    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
     repo_name = get_repo(shard.model_id, inference_engine_name)
     repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
       return self.completed_downloads[shard]
@@ -77,3 +89,79 @@ class HFShardDownloader(ShardDownloader):
   @property
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
     return self._on_progress
+
+  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
+    if not self.current_shard or not self.current_repo_id:
+      if DEBUG >= 2:
+        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
+      return None
+
+    try:
+      # If no snapshot directory exists, return None - no need to check remote files
+      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
+      if not snapshot_dir:
+        if DEBUG >= 2:
+          print(f"No snapshot directory found for {self.current_repo_id}")
+        return None
+
+      # Get the weight map to know what files we need
+      weight_map = await get_weight_map(self.current_repo_id, self.revision)
+      if not weight_map:
+        if DEBUG >= 2:
+          print(f"No weight map found for {self.current_repo_id}")
+        return None
+
+      # Get all files needed for this shard
+      patterns = get_allow_patterns(weight_map, self.current_shard)
+
+      # Check download status for all relevant files
+      status = {}
+      total_bytes = 0
+      downloaded_bytes = 0
+
+      async with aiohttp.ClientSession() as session:
+        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
+        relevant_files = list(
+            filter_repo_objects(
+                file_list, allow_patterns=patterns, key=lambda x: x["path"]))
+
+        for file in relevant_files:
+          file_size = file["size"]
+          total_bytes += file_size
+
+          percentage = await get_file_download_percentage(
+              session,
+              self.current_repo_id,
+              self.revision,
+              file["path"],
+              snapshot_dir,
+          )
+          status[file["path"]] = percentage
+          downloaded_bytes += (file_size * (percentage / 100))
+
+        # Add overall progress weighted by file size
+        if total_bytes > 0:
+          status["overall"] = (downloaded_bytes / total_bytes) * 100
+        else:
+          status["overall"] = 0
+          
+        # Add total size in bytes
+        status["total_size"] = total_bytes
+        if status["overall"] != 100:
+          status["total_downloaded"] = downloaded_bytes
+        
+
+        if DEBUG >= 2:
+          print(f"Download calculation for {self.current_repo_id}:")
+          print(f"Total bytes: {total_bytes}")
+          print(f"Downloaded bytes: {downloaded_bytes}")
+          for file in relevant_files:
+            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
+
+      return status
+
+    except Exception as e:
+      if DEBUG >= 2:
+        print(f"Error getting shard download status: {e}")
+        traceback.print_exc()
+      return None

+ 11 - 1
exo/download/shard_download.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from pathlib import Path
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.download_progress import RepoProgressEvent
@@ -26,6 +26,16 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
     pass
 
 
+  @abstractmethod
+  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+    """Get the download status of shards.
+    
+    Returns:
+        Optional[Dict[str, float]]: A dictionary mapping shard IDs to their download percentage (0-100),
+        or None if status cannot be determined
+    """
+    pass
+
 
 
 class NoopShardDownloader(ShardDownloader):
 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:

+ 202 - 29
exo/tinychat/index.css

@@ -21,8 +21,8 @@ main {
 .home {
 .home {
   width: 100%;
   width: 100%;
   height: 90%;
   height: 90%;
-
   margin-bottom: 10rem;
   margin-bottom: 10rem;
+  padding-top: 2rem;
 }
 }
 
 
 .title {
 .title {
@@ -129,8 +129,9 @@ main {
   flex-direction: column;
   flex-direction: column;
   gap: 1rem;
   gap: 1rem;
   align-items: center;
   align-items: center;
-  padding-top: 1rem;
+  padding-top: 4rem;
   padding-bottom: 11rem;
   padding-bottom: 11rem;
+  margin: 0 auto;
 }
 }
 
 
 .message {
 .message {
@@ -149,10 +150,17 @@ main {
   color: #000;
   color: #000;
 }
 }
 .download-progress {
 .download-progress {
-  margin-bottom: 12em;
+  position: fixed;
+  bottom: 11rem;
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
+  width: 100%;
+  max-width: 1200px;
   overflow-y: auto;
   overflow-y: auto;
   min-height: 350px;
   min-height: 350px;
   padding: 2rem;
   padding: 2rem;
+  z-index: 998;
 }
 }
 .message > pre {
 .message > pre {
   white-space: pre-wrap;
   white-space: pre-wrap;
@@ -271,23 +279,24 @@ main {
 }
 }
 
 
 .input-container {
 .input-container {
-  position: absolute;
+  position: fixed;
   bottom: 0;
   bottom: 0;
-
-  /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(
-    0deg,
-    var(--primary-bg-color) 55%,
-    transparent 100%
-  );
-
-  width: 100%;
+  left: 250px;
+  width: calc(100% - 250px);
   max-width: 1200px;
   max-width: 1200px;
   display: flex;
   display: flex;
   flex-direction: column;
   flex-direction: column;
   justify-content: center;
   justify-content: center;
   align-items: center;
   align-items: center;
   z-index: 999;
   z-index: 999;
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
 }
 }
 
 
 .input-performance {
 .input-performance {
@@ -372,22 +381,7 @@ p {
 }
 }
 
 
 .model-selector {
 .model-selector {
-  display: flex;
-  justify-content: center;
-  padding: 20px 0;
-}
-.model-selector select {
-  padding: 10px 20px;
-  font-size: 16px;
-  border: 1px solid #ccc;
-  border-radius: 5px;
-  background-color: #f8f8f8;
-  cursor: pointer;
-}
-.model-selector select:focus {
-  outline: none;
-  border-color: #007bff;
-  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+  display: none;
 }
 }
 
 
 /* Image upload button styles */
 /* Image upload button styles */
@@ -481,4 +475,183 @@ p {
 
 
 .clear-history-button i {
 .clear-history-button i {
   font-size: 14px;
   font-size: 14px;
+}
+
+/* Add new sidebar styles */
+.sidebar {
+  position: fixed;
+  left: 0;
+  top: 0;
+  bottom: 0;
+  width: 250px;
+  background-color: var(--secondary-color);
+  padding: 20px;
+  overflow-y: auto;
+  z-index: 1000;
+}
+
+.model-option {
+  padding: 12px;
+  margin: 8px 0;
+  border-radius: 8px;
+  background-color: var(--primary-bg-color);
+  cursor: pointer;
+  transition: all 0.2s ease;
+}
+
+.model-option:hover {
+  transform: translateX(5px);
+}
+
+.model-option.selected {
+  border-left: 3px solid var(--primary-color);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+}
+
+.model-progress {
+  font-size: 0.9em;
+  color: var(--secondary-color-transparent);
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.model-progress-info {
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.model-progress i {
+  font-size: 0.9em;
+  color: var(--primary-color);
+}
+
+/* Adjust main content to accommodate sidebar */
+main {
+  margin-left: 250px;
+  width: calc(100% - 250px);
+}
+
+/* Add styles for the back button */
+.back-button {
+  position: fixed;
+  top: 1rem;
+  left: calc(250px + 1rem); /* Sidebar width + padding */
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  padding: 0.5rem 1rem;
+  border-radius: 8px;
+  border: none;
+  cursor: pointer;
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  z-index: 1000;
+  transition: all 0.2s ease;
+}
+
+.back-button:hover {
+  transform: translateX(-5px);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-info {
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.model-size {
+  font-size: 0.8em;
+  color: var(--secondary-color-transparent);
+  opacity: 0.8;
+}
+
+.model-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    margin-bottom: 4px;
+}
+
+.model-delete-button {
+    background: none;
+    border: none;
+    color: var(--red-color);
+    padding: 4px 8px;
+    cursor: pointer;
+    transition: all 0.2s ease;
+    opacity: 0.7;
+}
+
+.model-delete-button:hover {
+    opacity: 1;
+    transform: scale(1.1);
+}
+
+.model-option:hover .model-delete-button {
+    opacity: 1;
+}
+
+.loading-container {
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    gap: 10px;
+    padding: 20px;
+    color: var(--secondary-color-transparent);
+}
+
+.loading-container i {
+    font-size: 24px;
+}
+
+.loading-container span {
+    font-size: 14px;
+}
+
+/* Add this to your CSS */
+.fa-spin {
+    animation: fa-spin 2s infinite linear;
+}
+
+@keyframes fa-spin {
+    0% {
+        transform: rotate(0deg);
+    }
+    100% {
+        transform: rotate(360deg);
+    }
+}
+
+.model-download-button {
+  background: none;
+  border: none;
+  color: var(--primary-color);
+  padding: 4px 8px;
+  border-radius: 4px;
+  cursor: pointer;
+  transition: all 0.2s ease;
+  display: inline-flex;
+  align-items: center;
+  gap: 6px;
+  background-color: var(--primary-bg-color);
+  font-size: 0.9em;
+  width: fit-content;
+  align-self: flex-start;
+}
+
+.model-download-button:hover {
+  transform: scale(1.05);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-download-button i {
+  font-size: 0.9em;
 }
 }

+ 119 - 33
exo/tinychat/index.html

@@ -25,14 +25,73 @@
 </head>
 </head>
 <body>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
 <main x-data="state" x-init="console.log(endpoint)">
-     <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity class="toast">
+  <div class="sidebar">
+    <h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
+    
+    <!-- Loading indicator -->
+    <div class="loading-container" x-show="Object.keys(models).length === 0">
+        <i class="fas fa-spinner fa-spin"></i>
+        <span>Loading models...</span>
+    </div>
+    
+    <template x-for="(model, key) in models" :key="key">
+        <div class="model-option" 
+             :class="{ 'selected': cstate.selectedModel === key }"
+             @click="cstate.selectedModel = key">
+            <div class="model-header">
+                <div class="model-name" x-text="model.name"></div>
+                <button 
+                    @click.stop="deleteModel(key, model)"
+                    class="model-delete-button"
+                    x-show="model.download_percentage > 0">
+                    <i class="fas fa-trash"></i>
+                </button>
+            </div>
+            <div class="model-info">
+                <div class="model-progress">
+                    <template x-if="model.loading">
+                        <span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
+                    </template>
+                    <div class="model-progress-info">
+                        <template x-if="!model.loading && model.download_percentage != null">
+                            <span>
+                                <!-- Check if there's an active download for this model -->
+                                <template x-if="downloadProgress?.some(p => 
+                                    p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
+                                )">
+                                    <i class="fas fa-circle-notch fa-spin"></i>
+                                </template>
+                                <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+                            </span>
+                        </template>
+                        <template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
+                            <button 
+                                @click.stop="handleDownload(key)"
+                                class="model-download-button">
+                                <i class="fas fa-download"></i>
+                                <span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
+                            </button>
+                        </template>
+                    </div>
+                </div>
+                <template x-if="model.total_size">
+                    <div class="model-size" x-text="model.total_downloaded ? 
+                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` : 
+                        formatBytes(model.total_size)">
+                    </div>
+                </template>
+            </div>
+        </div>
+    </template>
+  </div> 
+    <!-- Error Toast -->
+    <div x-show="errorMessage !== null" x-transition.opacity class="toast">
         <div class="toast-header">
         <div class="toast-header">
-            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
             <div class="toast-header-buttons">
             <div class="toast-header-buttons">
                 <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
                 <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
                         class="toast-expand-button" 
                         class="toast-expand-button" 
-                        x-show="errorMessage.stack">
+                        x-show="errorMessage?.stack">
                     <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
                     <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
                 </button>
                 </button>
                 <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
                 <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
@@ -41,13 +100,10 @@
             </div>
             </div>
         </div>
         </div>
         <div class="toast-content" x-show="errorExpanded" x-transition>
         <div class="toast-content" x-show="errorExpanded" x-transition>
-            <span x-text="errorMessage.stack"></span>
+            <span x-text="errorMessage?.stack || ''"></span>
         </div>
         </div>
     </div>
     </div>
-<div class="model-selector">
-  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
-  </select>
-</div>
+
 <div @popstate.window="
 <div @popstate.window="
       if (home === 2) {
       if (home === 2) {
         home = -1;
         home = -1;
@@ -79,10 +135,8 @@
 <template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
 <template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
 <div @click="
 <div @click="
             cstate = _state;
             cstate = _state;
-            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
-            // updateTotalTokens(cstate.messages);
-            home = 1;
-            // ensure that going back in history will go back to home
+            if (!cstate.selectedModel) cstate.selectedModel = 'llama-3.2-1b';
+            home = 2;
             window.history.pushState({}, '', '/');
             window.history.pushState({}, '', '/');
           " @touchend="
           " @touchend="
             if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
             if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
@@ -108,6 +162,19 @@
 </template>
 </template>
 </div>
 </div>
 </div>
 </div>
+<button 
+    @click="
+        home = 0;
+        cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
+        time_till_first = 0;
+        tokens_per_second = 0;
+        total_tokens = 0;
+    " 
+    class="back-button"
+    x-show="home === 2">
+    <i class="fas fa-arrow-left"></i>
+    Back to Chats
+</button>
 <div class="messages" x-init="
 <div class="messages" x-init="
       $watch('cstate', value =&gt; {
       $watch('cstate', value =&gt; {
         $el.innerHTML = '';
         $el.innerHTML = '';
@@ -209,27 +276,46 @@
 <i class="fas fa-times"></i>
 <i class="fas fa-times"></i>
 </button>
 </button>
 </div>
 </div>
-<textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
-            home = (home === 0) ? 1 : home
-            if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;
+<textarea 
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+    :placeholder="
+        generating ? 'Generating...' : 
+        (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
+        'Say something'
+    "
+    @input="
+        home = (home === 0) ? 1 : home
+        if (cstate.messages.length === 0 && $el.value === '') home = -1;
 
 
-            if ($el.value !== '') {
-              const messages = [...cstate.messages];
-              messages.push({ role: 'user', content: $el.value });
-              // updateTotalTokens(messages);
-            } else {
-              if (cstate.messages.length === 0) total_tokens = 0;
-              // else updateTotalTokens(cstate.messages);
-            }
-          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)" autofocus="" class="input-form" id="input-form" rows="1" x-autosize="" x-effect="
-            console.log(generating);
-            if (!generating) $nextTick(() =&gt; {
-              $el.focus();
-              setTimeout(() =&gt; $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
-            });
-          " x-ref="inputForm"></textarea>
-<button :disabled="generating" @click="await handleSend()" class="input-button">
-<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
+        if ($el.value !== '') {
+            const messages = [...cstate.messages];
+            messages.push({ role: 'user', content: $el.value });
+            // updateTotalTokens(messages);
+        } else {
+            if (cstate.messages.length === 0) total_tokens = 0;
+            // else updateTotalTokens(cstate.messages);
+        }
+    "
+    @keydown.enter="await handleEnter($event)"
+    @keydown.escape.window="$focus.focus($el)"
+    autofocus=""
+    class="input-form"
+    id="input-form"
+    rows="1"
+    x-autosize=""
+    x-effect="
+        console.log(generating);
+        if (!generating) $nextTick(() => {
+            $el.focus();
+            setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
+        });
+    "
+    x-ref="inputForm"></textarea>
+<button 
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+    @click="await handleSend()" 
+    class="input-button">
+    <i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
 </button>
 </button>
 </div>
 </div>
 </div>
 </div>

+ 173 - 96
exo/tinychat/index.js

@@ -13,6 +13,8 @@ document.addEventListener("alpine:init", () => {
     home: 0,
     home: 0,
     generating: false,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
     endpoint: `${window.location.origin}/v1`,
+    
+    // Initialize error message structure
     errorMessage: null,
     errorMessage: null,
     errorExpanded: false,
     errorExpanded: false,
     errorTimeout: null,
     errorTimeout: null,
@@ -32,12 +34,81 @@ document.addEventListener("alpine:init", () => {
     // Pending message storage
     // Pending message storage
     pendingMessage: null,
     pendingMessage: null,
 
 
+    modelPoolInterval: null,
+
+    // Add models state alongside existing state
+    models: {},
+
     init() {
     init() {
       // Clean up any pending messages
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
       localStorage.removeItem("pendingMessage");
 
 
+      // Get initial model list
+      this.fetchInitialModels();
+
       // Start polling for download progress
       // Start polling for download progress
       this.startDownloadProgressPolling();
       this.startDownloadProgressPolling();
+      
+      // Start model polling with the new pattern
+      this.startModelPolling();
+    },
+
+    async fetchInitialModels() {
+      try {
+        const response = await fetch(`${window.location.origin}/initial_models`);
+        if (response.ok) {
+          const initialModels = await response.json();
+          this.models = initialModels;
+        }
+      } catch (error) {
+        console.error('Error fetching initial models:', error);
+      }
+    },
+
+    async startModelPolling() {
+      while (true) {
+        try {
+          await this.populateSelector();
+          // Wait 5 seconds before next poll
+          await new Promise(resolve => setTimeout(resolve, 5000));
+        } catch (error) {
+          console.error('Model polling error:', error);
+          // If there's an error, wait before retrying
+          await new Promise(resolve => setTimeout(resolve, 5000));
+        }
+      }
+    },
+
+    async populateSelector() {
+      return new Promise((resolve, reject) => {
+        const evtSource = new EventSource(`${window.location.origin}/modelpool`);
+        
+        evtSource.onmessage = (event) => {
+          if (event.data === "[DONE]") {
+            evtSource.close();
+            resolve();
+            return;
+          }
+          
+          const modelData = JSON.parse(event.data);
+          // Update existing model data while preserving other properties
+          Object.entries(modelData).forEach(([modelName, data]) => {
+            if (this.models[modelName]) {
+              this.models[modelName] = {
+                ...this.models[modelName],
+                ...data,
+                loading: false
+              };
+            }
+          });
+        };
+        
+        evtSource.onerror = (error) => {
+          console.error('EventSource failed:', error);
+          evtSource.close();
+          reject(error);
+        };
+      });
     },
     },
 
 
     removeHistory(cstate) {
     removeHistory(cstate) {
@@ -74,56 +145,6 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
       return `${s}s`;
     },
     },
 
 
-    async populateSelector() {
-      try {
-        const response = await fetch(`${window.location.origin}/modelpool`);
-        const responseText = await response.text(); // Get raw response text first
-        
-        if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        
-        // Try to parse the response text
-        let responseJson;
-        try {
-          responseJson = JSON.parse(responseText);
-        } catch (parseError) {
-          console.error('Failed to parse JSON:', parseError);
-          throw new Error(`Invalid JSON response: ${responseText}`);
-        }
-
-        const sel = document.querySelector(".model-select");
-        if (!sel) {
-          throw new Error("Could not find model selector element");
-        }
-
-        // Clear the current options and add new ones
-        sel.innerHTML = '';
-          
-        const modelDict = responseJson["model pool"];
-        if (!modelDict) {
-          throw new Error("Response missing 'model pool' property");
-        }
-
-        Object.entries(modelDict).forEach(([key, value]) => {
-          const opt = document.createElement("option");
-          opt.value = key;
-          opt.textContent = value;
-          sel.appendChild(opt);
-        });
-
-        // Set initial value to the first model
-        const firstKey = Object.keys(modelDict)[0];
-        if (firstKey) {
-          sel.value = firstKey;
-          this.cstate.selectedModel = firstKey;
-        }
-      } catch (error) {
-        console.error("Error populating model selector:", error);
-        this.errorMessage = `Failed to load models: ${error.message}`;
-      }
-    },
-
     async handleImageUpload(event) {
     async handleImageUpload(event) {
       const file = event.target.files[0];
       const file = event.target.files[0];
       if (file) {
       if (file) {
@@ -169,29 +190,7 @@ document.addEventListener("alpine:init", () => {
         this.processMessage(value);
         this.processMessage(value);
       } catch (error) {
       } catch (error) {
         console.error('error', error);
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
         this.generating = false;
         this.generating = false;
       }
       }
     },
     },
@@ -309,29 +308,7 @@ document.addEventListener("alpine:init", () => {
         }
         }
       } catch (error) {
       } catch (error) {
         console.error('error', error);
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
       } finally {
       } finally {
         this.generating = false;
         this.generating = false;
       }
       }
@@ -467,6 +444,106 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
       }, 1000); // Poll every second
     },
     },
+
+    // Add a helper method to set errors consistently
+    setError(error) {
+      this.errorMessage = {
+        basic: error.message || "An unknown error occurred",
+        stack: error.stack || ""
+      };
+      this.errorExpanded = false;
+      
+      if (this.errorTimeout) {
+        clearTimeout(this.errorTimeout);
+      }
+
+      if (!this.errorExpanded) {
+        this.errorTimeout = setTimeout(() => {
+          this.errorMessage = null;
+          this.errorExpanded = false;
+        }, 30 * 1000);
+      }
+    },
+
+    async deleteModel(modelName, model) {
+      const downloadedSize = model.total_downloaded || 0;
+      const sizeMessage = downloadedSize > 0 ? 
+        `This will free up ${this.formatBytes(downloadedSize)} of space.` :
+        'This will remove any partially downloaded files.';
+      
+      if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
+        return;
+      }
+
+      try {
+        const response = await fetch(`${window.location.origin}/models/${modelName}`, {
+          method: 'DELETE',
+          headers: {
+            'Content-Type': 'application/json'
+          }
+        });
+
+        const data = await response.json();
+        
+        if (!response.ok) {
+          throw new Error(data.detail || 'Failed to delete model');
+        }
+
+        // Update the model status in the UI
+        if (this.models[modelName]) {
+          this.models[modelName].downloaded = false;
+          this.models[modelName].download_percentage = 0;
+          this.models[modelName].total_downloaded = 0;
+        }
+
+        // If this was the selected model, switch to a different one
+        if (this.cstate.selectedModel === modelName) {
+          const availableModel = Object.keys(this.models).find(key => this.models[key].downloaded);
+          this.cstate.selectedModel = availableModel || 'llama-3.2-1b';
+        }
+
+        // Show success message
+        console.log(`Model deleted successfully from: ${data.path}`);
+
+        // Refresh the model list
+        await this.populateSelector();
+      } catch (error) {
+        console.error('Error deleting model:', error);
+        this.setError(error.message || 'Failed to delete model');
+      }
+    },
+
+    async handleDownload(modelName) {
+      try {
+        const response = await fetch(`${window.location.origin}/download`, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json'
+          },
+          body: JSON.stringify({
+            model: modelName
+          })
+        });
+
+        const data = await response.json();
+
+        if (!response.ok) {
+          throw new Error(data.error || 'Failed to start download');
+        }
+
+        // Update the model's status immediately when download starts
+        if (this.models[modelName]) {
+          this.models[modelName] = {
+            ...this.models[modelName],
+            loading: true
+          };
+        }
+
+      } catch (error) {
+        console.error('Error starting download:', error);
+        this.setError(error);
+      }
+    }
   }));
   }));
 });
 });