Browse Source

Merge pull request #456 from cadenmackenzie/downloadedModelsV2

Show downloaded models, improve error handling, add the ability to delete models, add a sidebar with more detail, and add a button to go back to chat history
Alex Cheema 6 months ago
parent
commit
3adaba6ab8

+ 123 - 9
exo/api/chatgpt_api.py

@@ -2,6 +2,7 @@ import uuid
 import time
 import asyncio
 import json
+import os
 from pathlib import Path
 from transformers import AutoTokenizer
 from typing import List, Literal, Union, Dict
@@ -14,10 +15,12 @@ from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
-from exo.apputil import create_animation_mp4
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
-import tempfile
+from exo.download.hf.hf_shard_download import HFShardDownloader
+import shutil
+from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
+from exo.apputil import create_animation_mp4
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -175,6 +178,8 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
+    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
 
@@ -216,12 +221,56 @@ class ChatGPTAPI:
     return web.json_response({"status": "ok"})
 
   async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
-      }
-    })
+    try:
+        response = web.StreamResponse(
+            status=200,
+            reason='OK',
+            headers={
+                'Content-Type': 'text/event-stream',
+                'Cache-Control': 'no-cache',
+                'Connection': 'keep-alive',
+            }
+        )
+        await response.prepare(request)
+        
+        for model_name, pretty in pretty_name.items():
+            if model_name in model_cards:
+                model_info = model_cards[model_name]
+                
+                if self.inference_engine_classname in model_info.get("repo", {}):
+                    shard = build_base_shard(model_name, self.inference_engine_classname)
+                    if shard:
+                        downloader = HFShardDownloader(quick_check=True)
+                        downloader.current_shard = shard
+                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+                        status = await downloader.get_shard_download_status()
+                        
+                        download_percentage = status.get("overall") if status else None
+                        total_size = status.get("total_size") if status else None
+                        total_downloaded = status.get("total_downloaded") if status else False
+                        
+                        model_data = {
+                            model_name: {
+                                "name": pretty,
+                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                                "download_percentage": download_percentage,
+                                "total_size": total_size,
+                                "total_downloaded": total_downloaded
+                            }
+                        }
+                        
+                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+        
+        await response.write(b"data: [DONE]\n\n")
+        return response
+        
+    except Exception as e:
+        print(f"Error in handle_model_support: {str(e)}")
+        traceback.print_exc()
+        return web.json_response(
+            {"detail": f"Server error: {str(e)}"}, 
+            status=500
+        )
 
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
@@ -372,6 +421,71 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
+  async def handle_delete_model(self, request):
+    try:
+      model_name = request.match_info.get('model_name')
+      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
+      
+      if not model_name or model_name not in model_cards:
+        return web.json_response(
+          {"detail": f"Invalid model name: {model_name}"}, 
+          status=400
+          )
+
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard:
+        return web.json_response(
+          {"detail": "Could not build shard for model"}, 
+          status=400
+        )
+
+      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
+      
+      # Get the HF cache directory using the helper function
+      hf_home = get_hf_home()
+      cache_dir = get_repo_root(repo_id)
+      
+      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
+      
+      if os.path.exists(cache_dir):
+        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
+        try:
+          shutil.rmtree(cache_dir)
+          return web.json_response({
+            "status": "success", 
+            "message": f"Model {model_name} deleted successfully",
+            "path": str(cache_dir)
+          })
+        except Exception as e:
+          return web.json_response({
+            "detail": f"Failed to delete model files: {str(e)}"
+          }, status=500)
+      else:
+        return web.json_response({
+          "detail": f"Model files not found at {cache_dir}"
+        }, status=404)
+            
+    except Exception as e:
+        print(f"Error in handle_delete_model: {str(e)}")
+        traceback.print_exc()
+        return web.json_response({
+            "detail": f"Server error: {str(e)}"
+        }, status=500)
+
+  async def handle_get_initial_models(self, request):
+    model_data = {}
+    for model_name, pretty in pretty_name.items():
+        model_data[model_name] = {
+            "name": pretty,
+            "downloaded": None,  # Initially unknown
+            "download_percentage": None,  # Change from 0 to null
+            "total_size": None,
+            "total_downloaded": None,
+            "loading": True  # Add loading state
+        }
+    return web.json_response(model_data)
+
   async def handle_create_animation(self, request):
     try:
       data = await request.json()

+ 62 - 2
exo/download/hf/hf_helpers.py

@@ -166,10 +166,18 @@ async def download_file(
     downloaded_size = local_file_size
     downloaded_this_session = 0
     mode = 'ab' if use_range_request else 'wb'
-    if downloaded_size == total_size:
+    percentage = await get_file_download_percentage(
+      session,
+      repo_id,
+      revision,
+      file_path,
+      Path(save_directory)
+    )
+    
+    if percentage == 100:
       if DEBUG >= 2: print(f"File already downloaded: {file_path}")
       if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
+        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
       return
 
     if response.status == 200:
@@ -432,6 +440,57 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 
async def get_file_download_percentage(
    session: aiohttp.ClientSession,
    repo_id: str,
    revision: str,
    file_path: str,
    snapshot_dir: Path,
) -> float:
  """Return how much of *file_path* has been downloaded, as a percentage.

  Compares the size of the local copy under *snapshot_dir* against the remote
  size reported by the HF endpoint. Returns 0 when the file is missing
  locally, empty, or the remote size cannot be determined; returns exactly
  100.0 only when local and remote sizes match.
  """
  try:
    local_path = snapshot_dir / file_path
    if not await aios.path.exists(local_path):
      return 0

    # A zero-byte local file counts as "not started".
    local_size = await aios.path.getsize(local_path)
    if not local_size:
      return 0

    # Resolve the file's canonical URL and ask the server for its size via a
    # HEAD request, following redirects (e.g. to the CDN).
    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", file_path)
    headers = await get_auth_headers()

    async with session.head(url, headers=headers, allow_redirects=True) as response:
      if response.status != 200:
        if DEBUG >= 2:
          print(f"Failed to get remote file info for {file_path}: {response.status}")
        return 0

      remote_size = int(response.headers.get('Content-Length', 0))
      if remote_size == 0:
        if DEBUG >= 2:
          print(f"Remote size is 0 for {file_path}")
        return 0

      # Exact size match means fully downloaded; otherwise report the ratio.
      return 100.0 if local_size == remote_size else (local_size / remote_size) * 100

  except Exception as e:
    if DEBUG >= 2:
      print(f"Error checking file download status for {file_path}: {e}")
    return 0
+
 async def has_hf_home_read_access() -> bool:
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.R_OK)
@@ -441,3 +500,4 @@ async def has_hf_home_write_access() -> bool:
   hf_home = get_hf_home()
   try: return await aios.access(hf_home, os.W_OK)
   except OSError: return False
+

+ 90 - 2
exo/download/hf/hf_shard_download.py

@@ -1,13 +1,20 @@
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional, Union
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
-from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
+from exo.download.hf.hf_helpers import (
+    download_repo_files, RepoProgressEvent, get_weight_map, 
+    get_allow_patterns, get_repo_root, fetch_file_list, 
+    get_local_snapshot_dir, get_file_download_percentage,
+    filter_repo_objects
+)
 from exo.helpers import AsyncCallbackSystem, DEBUG
 from exo.models import model_cards, get_repo
+import aiohttp
+from aiofiles import os as aios
 
 
 class HFShardDownloader(ShardDownloader):
@@ -17,8 +24,13 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads: Dict[Shard, asyncio.Task] = {}
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
+    self.current_shard: Optional[Shard] = None
+    self.current_repo_id: Optional[str] = None
+    self.revision: str = "main"
 
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    self.current_shard = shard
+    self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
     repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
@@ -77,3 +89,79 @@ class HFShardDownloader(ShardDownloader):
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
+
  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
    """Report download progress for the currently-tracked shard.

    Returns a dict mapping each relevant file path to its download percentage
    (0-100), plus aggregate keys:
      - "overall": size-weighted overall percentage,
      - "total_size": total bytes required for this shard,
      - "total_downloaded": bytes downloaded so far (only present while the
        download is incomplete).
    Returns None when no shard/repo is set, no local snapshot exists, no
    weight map is available, or any error occurs.
    """
    if not self.current_shard or not self.current_repo_id:
      if DEBUG >= 2:
        print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
      return None

    try:
      # If no snapshot directory exists, return None - no need to check remote files
      snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
      if not snapshot_dir:
        if DEBUG >= 2:
          print(f"No snapshot directory found for {self.current_repo_id}")
        return None

      # Get the weight map to know what files we need
      weight_map = await get_weight_map(self.current_repo_id, self.revision)
      if not weight_map:
        if DEBUG >= 2:
          print(f"No weight map found for {self.current_repo_id}")
        return None

      # Get all files needed for this shard
      patterns = get_allow_patterns(weight_map, self.current_shard)

      # Check download status for all relevant files
      status = {}
      total_bytes = 0
      downloaded_bytes = 0

      async with aiohttp.ClientSession() as session:
        file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
        relevant_files = list(
            filter_repo_objects(
                file_list, allow_patterns=patterns, key=lambda x: x["path"]))

        for file in relevant_files:
          file_size = file["size"]
          total_bytes += file_size

          percentage = await get_file_download_percentage(
              session,
              self.current_repo_id,
              self.revision,
              file["path"],
              snapshot_dir,
          )
          status[file["path"]] = percentage
          # Partially-downloaded files contribute proportionally to the total.
          downloaded_bytes += (file_size * (percentage / 100))

        # Add overall progress weighted by file size
        if total_bytes > 0:
          status["overall"] = (downloaded_bytes / total_bytes) * 100
        else:
          status["overall"] = 0

        # Add total size in bytes
        status["total_size"] = total_bytes
        # Only report bytes-downloaded while the download is still incomplete;
        # a complete model carries no "total_downloaded" key.
        if status["overall"] != 100:
          status["total_downloaded"] = downloaded_bytes

        if DEBUG >= 2:
          print(f"Download calculation for {self.current_repo_id}:")
          print(f"Total bytes: {total_bytes}")
          print(f"Downloaded bytes: {downloaded_bytes}")
          for file in relevant_files:
            print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")

      return status

    except Exception as e:
      if DEBUG >= 2:
        print(f"Error getting shard download status: {e}")
        traceback.print_exc()
      return None

+ 11 - 1
exo/download/shard_download.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict
 from pathlib import Path
 from exo.inference.shard import Shard
 from exo.download.download_progress import RepoProgressEvent
@@ -26,6 +26,16 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
 
  @abstractmethod
  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
    """Get the download status of the current shard's files.

    Returns:
        Optional[Dict[str, float]]: A dictionary mapping file paths to their
        download percentage (0-100), possibly including aggregate keys such as
        "overall" and "total_size", or None if status cannot be determined.
    """
    pass
+
 
 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:

+ 202 - 29
exo/tinychat/index.css

@@ -21,8 +21,8 @@ main {
 .home {
   width: 100%;
   height: 90%;
-
   margin-bottom: 10rem;
+  padding-top: 2rem;
 }
 
 .title {
@@ -129,8 +129,9 @@ main {
   flex-direction: column;
   gap: 1rem;
   align-items: center;
-  padding-top: 1rem;
+  padding-top: 4rem;
   padding-bottom: 11rem;
+  margin: 0 auto;
 }
 
 .message {
@@ -149,10 +150,17 @@ main {
   color: #000;
 }
 .download-progress {
-  margin-bottom: 12em;
+  position: fixed;
+  bottom: 11rem;
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
+  width: 100%;
+  max-width: 1200px;
   overflow-y: auto;
   min-height: 350px;
   padding: 2rem;
+  z-index: 998;
 }
 .message > pre {
   white-space: pre-wrap;
@@ -271,23 +279,24 @@ main {
 }
 
 .input-container {
-  position: absolute;
+  position: fixed;
   bottom: 0;
-
-  /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(
-    0deg,
-    var(--primary-bg-color) 55%,
-    transparent 100%
-  );
-
-  width: 100%;
+  left: 250px;
+  width: calc(100% - 250px);
   max-width: 1200px;
   display: flex;
   flex-direction: column;
   justify-content: center;
   align-items: center;
   z-index: 999;
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
 }
 
 .input-performance {
@@ -372,22 +381,7 @@ p {
 }
 
 .model-selector {
-  display: flex;
-  justify-content: center;
-  padding: 20px 0;
-}
-.model-selector select {
-  padding: 10px 20px;
-  font-size: 16px;
-  border: 1px solid #ccc;
-  border-radius: 5px;
-  background-color: #f8f8f8;
-  cursor: pointer;
-}
-.model-selector select:focus {
-  outline: none;
-  border-color: #007bff;
-  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+  display: none;
 }
 
 /* Image upload button styles */
@@ -481,4 +475,183 @@ p {
 
 .clear-history-button i {
   font-size: 14px;
+}
+
+/* Add new sidebar styles */
+.sidebar {
+  position: fixed;
+  left: 0;
+  top: 0;
+  bottom: 0;
+  width: 250px;
+  background-color: var(--secondary-color);
+  padding: 20px;
+  overflow-y: auto;
+  z-index: 1000;
+}
+
+.model-option {
+  padding: 12px;
+  margin: 8px 0;
+  border-radius: 8px;
+  background-color: var(--primary-bg-color);
+  cursor: pointer;
+  transition: all 0.2s ease;
+}
+
+.model-option:hover {
+  transform: translateX(5px);
+}
+
+.model-option.selected {
+  border-left: 3px solid var(--primary-color);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+}
+
+.model-progress {
+  font-size: 0.9em;
+  color: var(--secondary-color-transparent);
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.model-progress-info {
+  display: flex;
+  flex-direction: column;
+  gap: 0.5rem;
+}
+
+.model-progress i {
+  font-size: 0.9em;
+  color: var(--primary-color);
+}
+
+/* Adjust main content to accommodate sidebar */
+main {
+  margin-left: 250px;
+  width: calc(100% - 250px);
+}
+
+/* Add styles for the back button */
+.back-button {
+  position: fixed;
+  top: 1rem;
+  left: calc(250px + 1rem); /* Sidebar width + padding */
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  padding: 0.5rem 1rem;
+  border-radius: 8px;
+  border: none;
+  cursor: pointer;
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  z-index: 1000;
+  transition: all 0.2s ease;
+}
+
+.back-button:hover {
+  transform: translateX(-5px);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-info {
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.model-size {
+  font-size: 0.8em;
+  color: var(--secondary-color-transparent);
+  opacity: 0.8;
+}
+
+.model-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    margin-bottom: 4px;
+}
+
+.model-delete-button {
+    background: none;
+    border: none;
+    color: var(--red-color);
+    padding: 4px 8px;
+    cursor: pointer;
+    transition: all 0.2s ease;
+    opacity: 0.7;
+}
+
+.model-delete-button:hover {
+    opacity: 1;
+    transform: scale(1.1);
+}
+
+.model-option:hover .model-delete-button {
+    opacity: 1;
+}
+
+.loading-container {
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    gap: 10px;
+    padding: 20px;
+    color: var(--secondary-color-transparent);
+}
+
+.loading-container i {
+    font-size: 24px;
+}
+
+.loading-container span {
+    font-size: 14px;
+}
+
+/* Add this to your CSS */
+.fa-spin {
+    animation: fa-spin 2s infinite linear;
+}
+
+@keyframes fa-spin {
+    0% {
+        transform: rotate(0deg);
+    }
+    100% {
+        transform: rotate(360deg);
+    }
+}
+
+.model-download-button {
+  background: none;
+  border: none;
+  color: var(--primary-color);
+  padding: 4px 8px;
+  border-radius: 4px;
+  cursor: pointer;
+  transition: all 0.2s ease;
+  display: inline-flex;
+  align-items: center;
+  gap: 6px;
+  background-color: var(--primary-bg-color);
+  font-size: 0.9em;
+  width: fit-content;
+  align-self: flex-start;
+}
+
+.model-download-button:hover {
+  transform: scale(1.05);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-download-button i {
+  font-size: 0.9em;
 }

+ 119 - 33
exo/tinychat/index.html

@@ -25,14 +25,73 @@
 </head>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
-     <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity class="toast">
+  <div class="sidebar">
+    <h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
+    
+    <!-- Loading indicator -->
+    <div class="loading-container" x-show="Object.keys(models).length === 0">
+        <i class="fas fa-spinner fa-spin"></i>
+        <span>Loading models...</span>
+    </div>
+    
+    <template x-for="(model, key) in models" :key="key">
+        <div class="model-option" 
+             :class="{ 'selected': cstate.selectedModel === key }"
+             @click="cstate.selectedModel = key">
+            <div class="model-header">
+                <div class="model-name" x-text="model.name"></div>
+                <button 
+                    @click.stop="deleteModel(key, model)"
+                    class="model-delete-button"
+                    x-show="model.download_percentage > 0">
+                    <i class="fas fa-trash"></i>
+                </button>
+            </div>
+            <div class="model-info">
+                <div class="model-progress">
+                    <template x-if="model.loading">
+                        <span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
+                    </template>
+                    <div class="model-progress-info">
+                        <template x-if="!model.loading && model.download_percentage != null">
+                            <span>
+                                <!-- Check if there's an active download for this model -->
+                                <template x-if="downloadProgress?.some(p => 
+                                    p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
+                                )">
+                                    <i class="fas fa-circle-notch fa-spin"></i>
+                                </template>
+                                <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+                            </span>
+                        </template>
+                        <template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
+                            <button 
+                                @click.stop="handleDownload(key)"
+                                class="model-download-button">
+                                <i class="fas fa-download"></i>
+                                <span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
+                            </button>
+                        </template>
+                    </div>
+                </div>
+                <template x-if="model.total_size">
+                    <div class="model-size" x-text="model.total_downloaded ? 
+                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` : 
+                        formatBytes(model.total_size)">
+                    </div>
+                </template>
+            </div>
+        </div>
+    </template>
+  </div> 
+    <!-- Error Toast -->
+    <div x-show="errorMessage !== null" x-transition.opacity class="toast">
         <div class="toast-header">
-            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
             <div class="toast-header-buttons">
                 <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
                         class="toast-expand-button" 
-                        x-show="errorMessage.stack">
+                        x-show="errorMessage?.stack">
                     <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
                 </button>
                 <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
@@ -41,13 +100,10 @@
             </div>
         </div>
         <div class="toast-content" x-show="errorExpanded" x-transition>
-            <span x-text="errorMessage.stack"></span>
+            <span x-text="errorMessage?.stack || ''"></span>
         </div>
     </div>
-<div class="model-selector">
-  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
-  </select>
-</div>
+
 <div @popstate.window="
       if (home === 2) {
         home = -1;
@@ -79,10 +135,8 @@
 <template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
 <div @click="
             cstate = _state;
-            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
-            // updateTotalTokens(cstate.messages);
-            home = 1;
-            // ensure that going back in history will go back to home
+            if (!cstate.selectedModel) cstate.selectedModel = 'llama-3.2-1b';
+            home = 2;
             window.history.pushState({}, '', '/');
           " @touchend="
             if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
@@ -108,6 +162,19 @@
 </template>
 </div>
 </div>
+<button 
+    @click="
+        home = 0;
+        cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
+        time_till_first = 0;
+        tokens_per_second = 0;
+        total_tokens = 0;
+    " 
+    class="back-button"
+    x-show="home === 2">
+    <i class="fas fa-arrow-left"></i>
+    Back to Chats
+</button>
 <div class="messages" x-init="
       $watch('cstate', value =&gt; {
         $el.innerHTML = '';
@@ -209,27 +276,46 @@
 <i class="fas fa-times"></i>
 </button>
 </div>
-<textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
-            home = (home === 0) ? 1 : home
-            if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;
+<textarea 
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+    :placeholder="
+        generating ? 'Generating...' : 
+        (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
+        'Say something'
+    "
+    @input="
+        home = (home === 0) ? 1 : home
+        if (cstate.messages.length === 0 && $el.value === '') home = -1;
 
-            if ($el.value !== '') {
-              const messages = [...cstate.messages];
-              messages.push({ role: 'user', content: $el.value });
-              // updateTotalTokens(messages);
-            } else {
-              if (cstate.messages.length === 0) total_tokens = 0;
-              // else updateTotalTokens(cstate.messages);
-            }
-          " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)" autofocus="" class="input-form" id="input-form" rows="1" x-autosize="" x-effect="
-            console.log(generating);
-            if (!generating) $nextTick(() =&gt; {
-              $el.focus();
-              setTimeout(() =&gt; $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
-            });
-          " x-ref="inputForm"></textarea>
-<button :disabled="generating" @click="await handleSend()" class="input-button">
-<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
+        if ($el.value !== '') {
+            const messages = [...cstate.messages];
+            messages.push({ role: 'user', content: $el.value });
+            // updateTotalTokens(messages);
+        } else {
+            if (cstate.messages.length === 0) total_tokens = 0;
+            // else updateTotalTokens(cstate.messages);
+        }
+    "
+    @keydown.enter="await handleEnter($event)"
+    @keydown.escape.window="$focus.focus($el)"
+    autofocus=""
+    class="input-form"
+    id="input-form"
+    rows="1"
+    x-autosize=""
+    x-effect="
+        console.log(generating);
+        if (!generating) $nextTick(() => {
+            $el.focus();
+            setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
+        });
+    "
+    x-ref="inputForm"></textarea>
+<button 
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+    @click="await handleSend()" 
+    class="input-button">
+    <i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
 </button>
 </div>
 </div>

+ 173 - 96
exo/tinychat/index.js

@@ -13,6 +13,8 @@ document.addEventListener("alpine:init", () => {
     home: 0,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
+    
+    // Initialize error message structure
     errorMessage: null,
     errorExpanded: false,
     errorTimeout: null,
@@ -32,12 +34,81 @@ document.addEventListener("alpine:init", () => {
     // Pending message storage
     pendingMessage: null,
 
+    modelPoolInterval: null,
+
+    // Add models state alongside existing state
+    models: {},
+
     init() {
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
 
+      // Get initial model list
+      this.fetchInitialModels();
+
       // Start polling for download progress
       this.startDownloadProgressPolling();
+      
+      // Start model polling with the new pattern
+      this.startModelPolling();
+    },
+
+    async fetchInitialModels() {
+      try {
+        const response = await fetch(`${window.location.origin}/initial_models`);
+        if (response.ok) {
+          const initialModels = await response.json();
+          this.models = initialModels;
+        }
+      } catch (error) {
+        console.error('Error fetching initial models:', error);
+      }
+    },
+
+    async startModelPolling() {
+      while (true) {
+        try {
+          await this.populateSelector();
+          // Wait 5 seconds before next poll
+          await new Promise(resolve => setTimeout(resolve, 5000));
+        } catch (error) {
+          console.error('Model polling error:', error);
+          // If there's an error, wait before retrying
+          await new Promise(resolve => setTimeout(resolve, 5000));
+        }
+      }
+    },
+
+    async populateSelector() {
+      return new Promise((resolve, reject) => {
+        const evtSource = new EventSource(`${window.location.origin}/modelpool`);
+        
+        evtSource.onmessage = (event) => {
+          if (event.data === "[DONE]") {
+            evtSource.close();
+            resolve();
+            return;
+          }
+          
+          const modelData = JSON.parse(event.data);
+          // Update existing model data while preserving other properties
+          Object.entries(modelData).forEach(([modelName, data]) => {
+            if (this.models[modelName]) {
+              this.models[modelName] = {
+                ...this.models[modelName],
+                ...data,
+                loading: false
+              };
+            }
+          });
+        };
+        
+        evtSource.onerror = (error) => {
+          console.error('EventSource failed:', error);
+          evtSource.close();
+          reject(error);
+        };
+      });
     },
 
     removeHistory(cstate) {
@@ -74,56 +145,6 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
     },
 
-    async populateSelector() {
-      try {
-        const response = await fetch(`${window.location.origin}/modelpool`);
-        const responseText = await response.text(); // Get raw response text first
-        
-        if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
-        }
-        
-        // Try to parse the response text
-        let responseJson;
-        try {
-          responseJson = JSON.parse(responseText);
-        } catch (parseError) {
-          console.error('Failed to parse JSON:', parseError);
-          throw new Error(`Invalid JSON response: ${responseText}`);
-        }
-
-        const sel = document.querySelector(".model-select");
-        if (!sel) {
-          throw new Error("Could not find model selector element");
-        }
-
-        // Clear the current options and add new ones
-        sel.innerHTML = '';
-          
-        const modelDict = responseJson["model pool"];
-        if (!modelDict) {
-          throw new Error("Response missing 'model pool' property");
-        }
-
-        Object.entries(modelDict).forEach(([key, value]) => {
-          const opt = document.createElement("option");
-          opt.value = key;
-          opt.textContent = value;
-          sel.appendChild(opt);
-        });
-
-        // Set initial value to the first model
-        const firstKey = Object.keys(modelDict)[0];
-        if (firstKey) {
-          sel.value = firstKey;
-          this.cstate.selectedModel = firstKey;
-        }
-      } catch (error) {
-        console.error("Error populating model selector:", error);
-        this.errorMessage = `Failed to load models: ${error.message}`;
-      }
-    },
-
     async handleImageUpload(event) {
       const file = event.target.files[0];
       if (file) {
@@ -169,29 +190,7 @@ document.addEventListener("alpine:init", () => {
         this.processMessage(value);
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
         this.generating = false;
       }
     },
@@ -309,29 +308,7 @@ document.addEventListener("alpine:init", () => {
         }
       } catch (error) {
         console.error('error', error);
-        const errorDetails = {
-            message: error.message || 'Unknown error',
-            stack: error.stack,
-            name: error.name || 'Error'
-        };
-        
-        this.errorMessage = {
-            basic: `${errorDetails.name}: ${errorDetails.message}`,
-            stack: errorDetails.stack
-        };
-
-        // Clear any existing timeout
-        if (this.errorTimeout) {
-            clearTimeout(this.errorTimeout);
-        }
-
-        // Only set the timeout if the error details aren't expanded
-        if (!this.errorExpanded) {
-            this.errorTimeout = setTimeout(() => {
-                this.errorMessage = null;
-                this.errorExpanded = false;
-            }, 30 * 1000);
-        }
+        this.setError(error);
       } finally {
         this.generating = false;
       }
@@ -467,6 +444,106 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
     },
+
+    // Add a helper method to set errors consistently
+    setError(error) {
+      this.errorMessage = {
+        basic: error.message || "An unknown error occurred",
+        stack: error.stack || ""
+      };
+      this.errorExpanded = false;
+      
+      if (this.errorTimeout) {
+        clearTimeout(this.errorTimeout);
+      }
+
+      if (!this.errorExpanded) {
+        this.errorTimeout = setTimeout(() => {
+          this.errorMessage = null;
+          this.errorExpanded = false;
+        }, 30 * 1000);
+      }
+    },
+
+    async deleteModel(modelName, model) {
+      const downloadedSize = model.total_downloaded || 0;
+      const sizeMessage = downloadedSize > 0 ? 
+        `This will free up ${this.formatBytes(downloadedSize)} of space.` :
+        'This will remove any partially downloaded files.';
+      
+      if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
+        return;
+      }
+
+      try {
+        const response = await fetch(`${window.location.origin}/models/${modelName}`, {
+          method: 'DELETE',
+          headers: {
+            'Content-Type': 'application/json'
+          }
+        });
+
+        const data = await response.json();
+        
+        if (!response.ok) {
+          throw new Error(data.detail || 'Failed to delete model');
+        }
+
+        // Update the model status in the UI
+        if (this.models[modelName]) {
+          this.models[modelName].downloaded = false;
+          this.models[modelName].download_percentage = 0;
+          this.models[modelName].total_downloaded = 0;
+        }
+
+        // If this was the selected model, switch to a different one
+        if (this.cstate.selectedModel === modelName) {
+          const availableModel = Object.keys(this.models).find(key => this.models[key].downloaded);
+          this.cstate.selectedModel = availableModel || 'llama-3.2-1b';
+        }
+
+        // Show success message
+        console.log(`Model deleted successfully from: ${data.path}`);
+
+        // Refresh the model list
+        await this.populateSelector();
+      } catch (error) {
+        console.error('Error deleting model:', error);
+        this.setError(error.message || 'Failed to delete model');
+      }
+    },
+
+    async handleDownload(modelName) {
+      try {
+        const response = await fetch(`${window.location.origin}/download`, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json'
+          },
+          body: JSON.stringify({
+            model: modelName
+          })
+        });
+
+        const data = await response.json();
+
+        if (!response.ok) {
+          throw new Error(data.error || 'Failed to start download');
+        }
+
+        // Update the model's status immediately when download starts
+        if (this.models[modelName]) {
+          this.models[modelName] = {
+            ...this.models[modelName],
+            loading: true
+          };
+        }
+
+      } catch (error) {
+        console.error('Error starting download:', error);
+        this.setError(error);
+      }
+    }
   }));
 });