浏览代码

Merge pull request #6 from cadenmackenzie/modelSideBarV2

Model side bar v2
Caden MacKenzie 8 月之前
父节点
当前提交
c469d5352c
共有 5 个文件被更改,包括 311 次插入和 68 次删除
  1. 60 1
      exo/api/chatgpt_api.py
  2. 8 2
      exo/download/hf/hf_shard_download.py
  3. 131 29
      exo/tinychat/index.css
  4. 48 9
      exo/tinychat/index.html
  5. 64 27
      exo/tinychat/index.js

+ 60 - 1
exo/api/chatgpt_api.py

@@ -19,6 +19,8 @@ from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
 from exo.download.hf.hf_shard_download import HFShardDownloader
+import shutil
+from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -176,6 +178,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
 
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -243,13 +246,17 @@ class ChatGPTAPI:
               
               # Get overall percentage from status
               download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
               if DEBUG >= 2 and download_percentage is not None:
                   print(f"Overall download percentage for {model_name}: {download_percentage}")
               
               model_pool[model_name] = {
                   "name": pretty,
                   "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                  "download_percentage": download_percentage
+                  "download_percentage": download_percentage,
+                  "total_size": total_size,
+                  "total_downloaded": total_downloaded
               }
       
       return web.json_response({"model pool": model_pool})
@@ -410,6 +417,58 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
+  async def handle_delete_model(self, request):
+    try:
+      model_name = request.match_info.get('model_name')
+      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
+      
+      if not model_name or model_name not in model_cards:
+        return web.json_response(
+          {"detail": f"Invalid model name: {model_name}"}, 
+          status=400
+          )
+
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard:
+        return web.json_response(
+          {"detail": "Could not build shard for model"}, 
+          status=400
+        )
+
+      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
+      
+      # Get the HF cache directory using the helper function
+      hf_home = get_hf_home()
+      cache_dir = get_repo_root(repo_id)
+      
+      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
+      
+      if os.path.exists(cache_dir):
+        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
+        try:
+          shutil.rmtree(cache_dir)
+          return web.json_response({
+            "status": "success", 
+            "message": f"Model {model_name} deleted successfully",
+            "path": str(cache_dir)
+          })
+        except Exception as e:
+          return web.json_response({
+            "detail": f"Failed to delete model files: {str(e)}"
+          }, status=500)
+      else:
+        return web.json_response({
+          "detail": f"Model files not found at {cache_dir}"
+        }, status=404)
+            
+    except Exception as e:
+        print(f"Error in handle_delete_model: {str(e)}")
+        traceback.print_exc()
+        return web.json_response({
+            "detail": f"Server error: {str(e)}"
+        }, status=500)
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()

+ 8 - 2
exo/download/hf/hf_shard_download.py

@@ -1,7 +1,7 @@
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Tuple, Optional, Union
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
@@ -90,7 +90,7 @@ class HFShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
 
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
     if not self.current_shard or not self.current_repo_id:
       if DEBUG >= 2:
         print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
@@ -144,6 +144,12 @@ class HFShardDownloader(ShardDownloader):
           status["overall"] = (downloaded_bytes / total_bytes) * 100
         else:
           status["overall"] = 0
+          
+        # Add total size in bytes
+        status["total_size"] = total_bytes
+        if status["overall"] != 100:
+          status["total_downloaded"] = downloaded_bytes
+        
 
         if DEBUG >= 2:
           print(f"Download calculation for {self.current_repo_id}:")

+ 131 - 29
exo/tinychat/index.css

@@ -21,8 +21,8 @@ main {
 .home {
   width: 100%;
   height: 90%;
-
   margin-bottom: 10rem;
+  padding-top: 2rem;
 }
 
 .title {
@@ -129,8 +129,9 @@ main {
   flex-direction: column;
   gap: 1rem;
   align-items: center;
-  padding-top: 1rem;
+  padding-top: 4rem;
   padding-bottom: 11rem;
+  margin: 0 auto;
 }
 
 .message {
@@ -149,10 +150,17 @@ main {
   color: #000;
 }
 .download-progress {
-  margin-bottom: 12em;
+  position: fixed;
+  bottom: 11rem;
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
+  width: 100%;
+  max-width: 1200px;
   overflow-y: auto;
   min-height: 350px;
   padding: 2rem;
+  z-index: 998;
 }
 .message > pre {
   white-space: pre-wrap;
@@ -271,23 +279,24 @@ main {
 }
 
 .input-container {
-  position: absolute;
+  position: fixed;
   bottom: 0;
-
-  /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(
-    0deg,
-    var(--primary-bg-color) 55%,
-    transparent 100%
-  );
-
-  width: 100%;
+  left: 250px;
+  width: calc(100% - 250px);
   max-width: 1200px;
   display: flex;
   flex-direction: column;
   justify-content: center;
   align-items: center;
   z-index: 999;
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
+  left: 50%;
+  transform: translateX(-50%);
+  margin-left: 125px;
 }
 
 .input-performance {
@@ -372,22 +381,7 @@ p {
 }
 
 .model-selector {
-  display: flex;
-  justify-content: center;
-  padding: 20px 0;
-}
-.model-selector select {
-  padding: 10px 20px;
-  font-size: 16px;
-  border: 1px solid #ccc;
-  border-radius: 5px;
-  background-color: #f8f8f8;
-  cursor: pointer;
-}
-.model-selector select:focus {
-  outline: none;
-  border-color: #007bff;
-  box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+  display: none;
 }
 
 /* Image upload button styles */
@@ -481,4 +475,112 @@ p {
 
 .clear-history-button i {
   font-size: 14px;
+}
+
+/* Add new sidebar styles */
+.sidebar {
+  position: fixed;
+  left: 0;
+  top: 0;
+  bottom: 0;
+  width: 250px;
+  background-color: var(--secondary-color);
+  padding: 20px;
+  overflow-y: auto;
+  z-index: 1000;
+}
+
+.model-option {
+  padding: 12px;
+  margin: 8px 0;
+  border-radius: 8px;
+  background-color: var(--primary-bg-color);
+  cursor: pointer;
+  transition: all 0.2s ease;
+}
+
+.model-option:hover {
+  transform: translateX(5px);
+}
+
+.model-option.selected {
+  border-left: 3px solid var(--primary-color);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+}
+
+.model-progress {
+  font-size: 0.9em;
+  color: var(--secondary-color-transparent);
+}
+
+/* Adjust main content to accommodate sidebar */
+main {
+  margin-left: 250px;
+  width: calc(100% - 250px);
+}
+
+/* Add styles for the back button */
+.back-button {
+  position: fixed;
+  top: 1rem;
+  left: calc(250px + 1rem); /* Sidebar width + padding */
+  background-color: var(--secondary-color);
+  color: var(--foreground-color);
+  padding: 0.5rem 1rem;
+  border-radius: 8px;
+  border: none;
+  cursor: pointer;
+  display: flex;
+  align-items: center;
+  gap: 0.5rem;
+  z-index: 1000;
+  transition: all 0.2s ease;
+}
+
+.back-button:hover {
+  transform: translateX(-5px);
+  background-color: var(--secondary-color-transparent);
+}
+
+.model-info {
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.model-size {
+  font-size: 0.8em;
+  color: var(--secondary-color-transparent);
+  opacity: 0.8;
+}
+
+.model-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    margin-bottom: 4px;
+}
+
+.model-delete-button {
+    background: none;
+    border: none;
+    color: var(--red-color);
+    padding: 4px 8px;
+    cursor: pointer;
+    transition: all 0.2s ease;
+    opacity: 0.7;
+}
+
+.model-delete-button:hover {
+    opacity: 1;
+    transform: scale(1.1);
+}
+
+.model-option:hover .model-delete-button {
+    opacity: 1;
 }

+ 48 - 9
exo/tinychat/index.html

@@ -25,7 +25,38 @@
 </head>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
-     <!-- Error Toast -->
+  <div class="sidebar">
+    <h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
+    <template x-for="(model, key) in models" :key="key">
+        <div class="model-option" 
+             :class="{ 'selected': cstate.selectedModel === key }"
+             @click="cstate.selectedModel = key">
+            <div class="model-header">
+                <div class="model-name" x-text="model.name"></div>
+                <button 
+                    @click.stop="deleteModel(key, model)"
+                    class="model-delete-button"
+                    x-show="model.download_percentage > 0">
+                    <i class="fas fa-trash"></i>
+                </button>
+            </div>
+            <div class="model-info">
+                <div class="model-progress">
+                    <template x-if="model.download_percentage != null">
+                        <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+                    </template>
+                </div>
+                <template x-if="model.total_size">
+                    <div class="model-size" x-text="model.total_downloaded ? 
+                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` : 
+                        formatBytes(model.total_size)">
+                    </div>
+                </template>
+            </div>
+        </div>
+    </template>
+  </div> 
+    <!-- Error Toast -->
     <div x-show="errorMessage !== null" x-transition.opacity class="toast">
         <div class="toast-header">
             <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
@@ -44,10 +75,7 @@
             <span x-text="errorMessage?.stack || ''"></span>
         </div>
     </div>
-<div class="model-selector">
-  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" class='model-select'>
-  </select>
-</div>
+
 <div @popstate.window="
       if (home === 2) {
         home = -1;
@@ -79,10 +107,8 @@
 <template x-for="_state in histories.toSorted((a, b) =&gt; b.time - a.time)">
 <div @click="
             cstate = _state;
-            if (cstate) cstate.selectedModel = document.querySelector('.model-selector select').value
-            // updateTotalTokens(cstate.messages);
-            home = 1;
-            // ensure that going back in history will go back to home
+            if (!cstate.selectedModel) cstate.selectedModel = 'llama-3.2-1b';
+            home = 2;
             window.history.pushState({}, '', '/');
           " @touchend="
             if (Math.abs($event.changedTouches[0].clientX - otx) &gt; trigger) removeHistory(_state);
@@ -108,6 +134,19 @@
 </template>
 </div>
 </div>
+<button 
+    @click="
+        home = 0;
+        cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
+        time_till_first = 0;
+        tokens_per_second = 0;
+        total_tokens = 0;
+    " 
+    class="back-button"
+    x-show="home === 2">
+    <i class="fas fa-arrow-left"></i>
+    Back to Chats
+</button>
 <div class="messages" x-init="
       $watch('cstate', value =&gt; {
         $el.innerHTML = '';

+ 64 - 27
exo/tinychat/index.js

@@ -36,6 +36,9 @@ document.addEventListener("alpine:init", () => {
 
     modelPoolInterval: null,
 
+    // Add models state alongside existing state
+    models: {},
+
     init() {
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
@@ -93,34 +96,20 @@ document.addEventListener("alpine:init", () => {
 
         const data = await response.json();
         
-        const sel = document.querySelector('.model-select');
-        
-        // Only create options if they don't exist
-        if (sel.children.length === 0) {
-          Object.entries(data["model pool"]).forEach(([key, value]) => {
-            const opt = document.createElement("option");
-            opt.value = key;
-            opt.dataset.modelName = value.name;  // Store base name in dataset
-            opt.textContent = value.name;
-            sel.appendChild(opt);
-          });
-        }
-        
-        // Update existing options text
-        Array.from(sel.options).forEach(opt => {
-          const modelInfo = data["model pool"][opt.value];
-          if (modelInfo) {
-            let displayText = modelInfo.name;
-            if (modelInfo.download_percentage != null) {
-              if (modelInfo.downloaded) {
-                  displayText += ' (downloaded)';
-              } else {
-                  displayText += ` (${Math.round(modelInfo.download_percentage)}% downloaded)`;
-              }
-            }
-            opt.textContent = displayText;
+        // Update the models state with the full model pool data
+        Object.entries(data["model pool"]).forEach(([key, value]) => {
+          if (!this.models[key]) {
+            this.models[key] = value;
+          } else {
+            // Update existing model info while preserving reactivity
+            this.models[key].name = value.name;
+            this.models[key].downloaded = value.downloaded;
+            this.models[key].download_percentage = value.download_percentage;
+            this.models[key].total_size = value.total_size;
+            this.models[key].total_downloaded = value.total_downloaded;
           }
-      });
+        });
+                
       } catch (error) {
         console.error("Error populating model selector:", error);
         this.setError(error);
@@ -446,6 +435,54 @@ document.addEventListener("alpine:init", () => {
         }, 30 * 1000);
       }
     },
+
+    async deleteModel(modelName, model) {
+      const downloadedSize = model.total_downloaded || 0;
+      const sizeMessage = downloadedSize > 0 ? 
+        `This will free up ${this.formatBytes(downloadedSize)} of space.` :
+        'This will remove any partially downloaded files.';
+      
+      if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
+        return;
+      }
+
+      try {
+        const response = await fetch(`${window.location.origin}/models/${modelName}`, {
+          method: 'DELETE',
+          headers: {
+            'Content-Type': 'application/json'
+          }
+        });
+
+        const data = await response.json();
+        
+        if (!response.ok) {
+          throw new Error(data.detail || 'Failed to delete model');
+        }
+
+        // Update the model status in the UI
+        if (this.models[modelName]) {
+          this.models[modelName].downloaded = false;
+          this.models[modelName].download_percentage = 0;
+          this.models[modelName].total_downloaded = 0;
+        }
+
+        // If this was the selected model, switch to a different one
+        if (this.cstate.selectedModel === modelName) {
+          const availableModel = Object.keys(this.models).find(key => this.models[key].downloaded);
+          this.cstate.selectedModel = availableModel || 'llama-3.2-1b';
+        }
+
+        // Show success message
+        console.log(`Model deleted successfully from: ${data.path}`);
+
+        // Refresh the model list
+        await this.populateSelector();
+      } catch (error) {
+        console.error('Error deleting model:', error);
+        this.setError(error.message || 'Failed to delete model');
+      }
+    }
   }));
 });