Răsfoiți Sursa

Merge pull request #258 from DevEmilio96/main

Display simple download status in web UI
Alex Cheema 7 luni în urmă
părinte
comite
6b38346974
5 fișiere modificate, cu 187 adăugiri și 7 ștergeri
  1. 17 0
      exo/api/chatgpt_api.py
  2. 13 6
      exo/main.py
  3. 3 1
      exo/tinychat/index.css
  4. 15 0
      exo/tinychat/index.html
  5. 139 0
      exo/tinychat/index.js

+ 17 - 0
exo/api/chatgpt_api.py

@@ -9,6 +9,7 @@ from aiohttp import web
 import aiohttp_cors
 import traceback
 from exo import DEBUG, VERSION
+from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
@@ -175,6 +176,8 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    # Endpoint for download progress tracking
+    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
 
     self.static_dir = Path(__file__).parent.parent/"tinychat"
     self.app.router.add_get("/", self.handle_root)
@@ -203,6 +206,20 @@ class ChatGPTAPI:
     tokenizer = await resolve_tokenizer(shard.model_id)
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
 
+  async def handle_get_download_progress(self, request):
+    progress_data = {}
+    for node_id, progress_event in self.node.node_download_progress.items():
+        if isinstance(progress_event, RepoProgressEvent):
+            # Convert to dict if not already
+            progress_data[node_id] = progress_event.to_dict()
+        elif isinstance(progress_event, dict):
+            progress_data[node_id] = progress_event
+        else:
+            # Handle unexpected types
+            progress_data[node_id] = str(progress_event)
+    return web.json_response(progress_data)
+
+
   async def handle_post_chat_completions(self, request):
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")

+ 13 - 6
exo/main.py

@@ -90,6 +90,8 @@ node = StandardNode(
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
+node.download_progress = {}  # Initialize download progress tracking
+node.node_download_progress = {}  # For tracking per-node download progress
 api = ChatGPTAPI(
   node,
   inference_engine.__class__.__name__,
@@ -119,12 +121,17 @@ if args.prometheus_client_port:
 last_broadcast_time = 0
 
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-  global last_broadcast_time
-  current_time = time.time()
-  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-    last_broadcast_time = current_time
-    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
-
+    global last_broadcast_time
+    current_time = time.time()
+    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+        last_broadcast_time = current_time
+        node.download_progress[event.repo_id] = event.to_dict()
+        node.node_download_progress[node.id] = event.to_dict()
+        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
+            "type": "download_progress",
+            "node_id": node.id,
+            "progress": event.to_dict()
+        })))
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 

+ 3 - 1
exo/tinychat/index.css

@@ -164,7 +164,9 @@ main {
   border-right: 2px solid var(--secondary-color);
   box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
 }
-
+/* Bottom spacing for the download progress panel. */
+.download-progress{
+  margin-bottom: 20em;
+}
 .message > pre {
   white-space: pre-wrap;
 }

+ 15 - 0
exo/tinychat/index.html

@@ -149,6 +149,21 @@
       $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
     " x-ref="messages" x-show="home === 2" x-transition="">
 </div>
+
+<!-- Download Progress Section: rendered only while `downloadProgress` is truthy.
+     Shows the model (repo id @ revision), downloaded / total bytes with the
+     percentage, and speed and ETA (each falling back to 'N/A' when empty). -->
+<template x-if="downloadProgress">
+<div class="download-progress message message-role-assistant">
+  <h2>Download Progress</h2>
+  <div class="download-progress-node">
+    <p><strong>Model:</strong> <span x-text="downloadProgress.repo_id + '@' + downloadProgress.repo_revision"></span></p>
+    <p><strong>Progress:</strong> <span x-text="`${downloadProgress.downloaded_bytes_display} / ${downloadProgress.total_bytes_display} (${downloadProgress.percentage}%)`"></span></p>
+    <p><strong>Speed:</strong> <span x-text="downloadProgress.overall_speed_display || 'N/A'"></span></p>
+    <p><strong>ETA:</strong> <span x-text="downloadProgress.overall_eta_display || 'N/A'"></span></p>
+  </div>
+</div>
+</template>
+
+
 <div class="input-container">
 <div class="input-performance">
 <span class="input-performance-point">

+ 139 - 0
exo/tinychat/index.js

@@ -23,6 +23,19 @@ document.addEventListener("alpine:init", () => {
     // image handling
     imagePreview: null,
 
+    // download progress
+    downloadProgress: null,
+    downloadProgressInterval: null, // To keep track of the polling interval
+
+    // Pending message storage
+    pendingMessage: null,
+
+    init() {
+      // Clean up any pending messages
+      this.pendingMessage = null;
+      localStorage.removeItem("pendingMessage");
+    },
+
     removeHistory(cstate) {
       const index = this.histories.findIndex((state) => {
         return state.time === cstate.time;
@@ -32,6 +45,24 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("histories", JSON.stringify(this.histories));
       }
     },
+    // Utility functions
+    formatBytes(bytes) {
+      if (bytes === 0) return '0 B';
+      const k = 1024;
+      const sizes = ['B', 'KB', 'MB', 'GB', 'TB'];
+      const i = Math.floor(Math.log(bytes) / Math.log(k));
+      return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i];
+    },
+
+    formatDuration(seconds) {
+      if (seconds === null || seconds === undefined || isNaN(seconds)) return '';
+      const h = Math.floor(seconds / 3600);
+      const m = Math.floor((seconds % 3600) / 60);
+      const s = Math.floor(seconds % 60);
+      if (h > 0) return `${h}h ${m}m ${s}s`;
+      if (m > 0) return `${m}m ${s}s`;
+      return `${s}s`;
+    },
 
     async handleImageUpload(event) {
       const file = event.target.files[0];
@@ -74,6 +105,43 @@ document.addEventListener("alpine:init", () => {
         el.style.height = "auto";
         el.style.height = el.scrollHeight + "px";
 
+        // Proceed to handle the message
+        this.processMessage(value);
+
+        // Start polling for download progress
+        this.startDownloadProgressPolling();
+
+        // Delay the check for downloadProgress by 8 seconds without blocking execution
+        setTimeout(async () => {
+          this.pendingMessageHandler(value);
+        }, 8000);
+
+      } catch (error) {
+        console.error('error', error)
+        this.errorMessage = error.message || 'Errore durante l\'invio del messaggio.';
+        setTimeout(() => {
+          this.errorMessage = null;
+        }, 5 * 1000)
+      }
+    },
+
+    async pendingMessageHandler(value) {
+      console.log("Pending message handler called");
+      // Check if download is in progress
+      if (this.downloadProgress && this.downloadProgress.status !== "complete") {
+        // Save the message in pendingMessage
+        this.pendingMessage = value;
+        localStorage.setItem("pendingMessage", value);
+        console.log("Pending message saved:", localStorage.getItem("pendingMessage"));
+        // Inform the user
+        this.cstate.messages.push({ role: "system", content: "Download is in progress. Your message will be processed once the download completes." });
+        this.generating = false; // Reset generating
+        return;
+      }
+    },
+
+    async processMessage(value) {
+      try {
         // reset performance tracking
         const prefill_start = Date.now();
         let start_time = 0;
@@ -254,6 +322,77 @@ document.addEventListener("alpine:init", () => {
         }
       }
     },
+
+    async fetchDownloadProgress() {
+      try {
+        console.log("fetching download progress");
+        await new Promise(resolve => setTimeout(resolve, 4000)); // Necessary delay
+        const response = await fetch(`${this.endpoint}/download/progress`);
+        if (response.ok) {
+          const data = await response.json();
+          const progressArray = Object.values(data);
+          if (progressArray.length > 0) {
+            const progress = progressArray[0];
+            // Check if download is complete
+            if (progress.status === "complete" || progress.status === "failed") {
+              this.downloadProgress = null; // Hide the progress section
+              // Stop polling
+              this.stopDownloadProgressPolling();
+
+              if (progress.status === "complete") {
+                // Download is complete
+                // Check for pendingMessage
+                const savedMessage = localStorage.getItem("pendingMessage");
+                if (savedMessage) {
+                  // Clear pendingMessage
+                  this.pendingMessage = null;
+                  localStorage.removeItem("pendingMessage");
+                  // Call processMessage() with savedMessage
+                  await this.processMessage(savedMessage);
+                }
+              }
+            } else {
+              // Compute human-readable strings
+              progress.downloaded_bytes_display = this.formatBytes(progress.downloaded_bytes);
+              progress.total_bytes_display = this.formatBytes(progress.total_bytes);
+              progress.overall_speed_display = progress.overall_speed ? this.formatBytes(progress.overall_speed) + '/s' : '';
+              progress.overall_eta_display = progress.overall_eta ? this.formatDuration(progress.overall_eta) : '';
+              progress.percentage = ((progress.downloaded_bytes / progress.total_bytes) * 100).toFixed(2);
+
+              this.downloadProgress = progress;
+            }
+          } else {
+            // No ongoing download
+            this.downloadProgress = null;
+            // Stop polling
+            this.stopDownloadProgressPolling();
+          }
+        }
+      } catch (error) {
+        console.error("Error fetching download progress:", error);
+        this.downloadProgress = null;
+        // Stop polling in case of error
+        this.stopDownloadProgressPolling();
+      }
+    },
+
+    startDownloadProgressPolling() {
+      if (this.downloadProgressInterval) {
+        // Already polling
+        return;
+      }
+      this.fetchDownloadProgress(); // Fetch immediately
+      this.downloadProgressInterval = setInterval(() => {
+        this.fetchDownloadProgress();
+      }, 1000); // Poll every second
+    },
+
+    stopDownloadProgressPolling() {
+      if (this.downloadProgressInterval) {
+        clearInterval(this.downloadProgressInterval);
+        this.downloadProgressInterval = null;
+      }
+    },
   }));
 });