Ver código fonte

clean up download progress

Alex Cheema 7 meses atrás
pai
commit
4746ffdd60
3 arquivos alterados com 8 adições e 50 exclusões
  1. 4 8
      exo/api/chatgpt_api.py
  2. 0 4
      exo/main.py
  3. 4 38
      exo/tinychat/index.js

+ 4 - 8
exo/api/chatgpt_api.py

@@ -209,14 +209,10 @@ class ChatGPTAPI:
   async def handle_get_download_progress(self, request):
   async def handle_get_download_progress(self, request):
     progress_data = {}
     progress_data = {}
     for node_id, progress_event in self.node.node_download_progress.items():
     for node_id, progress_event in self.node.node_download_progress.items():
-        if isinstance(progress_event, RepoProgressEvent):
-            # Convert to dict if not already
-            progress_data[node_id] = progress_event.to_dict()
-        elif isinstance(progress_event, dict):
-            progress_data[node_id] = progress_event
-        else:
-            # Handle unexpected types
-            progress_data[node_id] = str(progress_event)
+      if isinstance(progress_event, RepoProgressEvent):
+        progress_data[node_id] = progress_event.to_dict()
+      else:
+        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
     return web.json_response(progress_data)
     return web.json_response(progress_data)
 
 
 
 

+ 0 - 4
exo/main.py

@@ -90,8 +90,6 @@ node = StandardNode(
 )
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 node.server = server
-node.download_progress = {}  # Initialize download progress tracking
-node.node_download_progress = {}  # For tracking per-node download progress
 api = ChatGPTAPI(
 api = ChatGPTAPI(
   node,
   node,
   inference_engine.__class__.__name__,
   inference_engine.__class__.__name__,
@@ -125,8 +123,6 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
     current_time = time.time()
     current_time = time.time()
     if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
     if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
         last_broadcast_time = current_time
         last_broadcast_time = current_time
-        node.download_progress[event.repo_id] = event.to_dict()
-        node.node_download_progress[node.id] = event.to_dict()
         asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
         asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
             "type": "download_progress",
             "type": "download_progress",
             "node_id": node.id,
             "node_id": node.id,

+ 4 - 38
exo/tinychat/index.js

@@ -32,8 +32,10 @@ document.addEventListener("alpine:init", () => {
 
 
     init() {
     init() {
       // Clean up any pending messages
       // Clean up any pending messages
-      this.pendingMessage = null;
       localStorage.removeItem("pendingMessage");
       localStorage.removeItem("pendingMessage");
+
+      // Start polling for download progress
+      this.startDownloadProgressPolling();
     },
     },
 
 
     removeHistory(cstate) {
     removeHistory(cstate) {
@@ -107,36 +109,14 @@ document.addEventListener("alpine:init", () => {
 
 
         // Proceed to handle the message
         // Proceed to handle the message
         this.processMessage(value);
         this.processMessage(value);
-
-        // Start polling for download progress
-        this.startDownloadProgressPolling();
-
-        // Delay the check for downloadProgress by 8 seconds without blocking execution
-        setTimeout(async () => {
-          this.pendingMessageHandler(value);
-        }, 8000);
-
       } catch (error) {
       } catch (error) {
         console.error('error', error)
         console.error('error', error)
         this.errorMessage = error.message || 'Errore durante l\'invio del messaggio.';
         this.errorMessage = error.message || 'Errore durante l\'invio del messaggio.';
         setTimeout(() => {
         setTimeout(() => {
           this.errorMessage = null;
           this.errorMessage = null;
         }, 5 * 1000)
         }, 5 * 1000)
-      }
-    },
-
-    async pendingMessageHandler(value) {
-      console.log("Pending message handler called");
-      // Check if download is in progress
-      if (this.downloadProgress && this.downloadProgress.status !== "complete") {
-        // Save the message in pendingMessage
-        this.pendingMessage = value;
         localStorage.setItem("pendingMessage", value);
         localStorage.setItem("pendingMessage", value);
-        console.log("Pending message saved:", localStorage.getItem("pendingMessage"));
-        // Inform the user
-        this.cstate.messages.push({ role: "system", content: "Download is in progress. Your message will be processed once the download completes." });
-        this.generating = false; // Reset generating
-        return;
+        this.generating = false;
       }
       }
     },
     },
 
 
@@ -325,8 +305,6 @@ document.addEventListener("alpine:init", () => {
 
 
     async fetchDownloadProgress() {
     async fetchDownloadProgress() {
       try {
       try {
-        console.log("fetching download progress");
-        await new Promise(resolve => setTimeout(resolve, 4000)); // Necessary delay
         const response = await fetch(`${this.endpoint}/download/progress`);
         const response = await fetch(`${this.endpoint}/download/progress`);
         if (response.ok) {
         if (response.ok) {
           const data = await response.json();
           const data = await response.json();
@@ -336,8 +314,6 @@ document.addEventListener("alpine:init", () => {
             // Check if download is complete
             // Check if download is complete
             if (progress.status === "complete" || progress.status === "failed") {
             if (progress.status === "complete" || progress.status === "failed") {
               this.downloadProgress = null; // Hide the progress section
               this.downloadProgress = null; // Hide the progress section
-              // Stop polling
-              this.stopDownloadProgressPolling();
 
 
               if (progress.status === "complete") {
               if (progress.status === "complete") {
                 // Download is complete
                 // Download is complete
@@ -345,7 +321,6 @@ document.addEventListener("alpine:init", () => {
                 const savedMessage = localStorage.getItem("pendingMessage");
                 const savedMessage = localStorage.getItem("pendingMessage");
                 if (savedMessage) {
                 if (savedMessage) {
                   // Clear pendingMessage
                   // Clear pendingMessage
-                  this.pendingMessage = null;
                   localStorage.removeItem("pendingMessage");
                   localStorage.removeItem("pendingMessage");
                   // Call processMessage() with savedMessage
                   // Call processMessage() with savedMessage
                   await this.processMessage(savedMessage);
                   await this.processMessage(savedMessage);
@@ -364,8 +339,6 @@ document.addEventListener("alpine:init", () => {
           } else {
           } else {
             // No ongoing download
             // No ongoing download
             this.downloadProgress = null;
             this.downloadProgress = null;
-            // Stop polling
-            this.stopDownloadProgressPolling();
           }
           }
         }
         }
       } catch (error) {
       } catch (error) {
@@ -386,13 +359,6 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
       }, 1000); // Poll every second
     },
     },
-
-    stopDownloadProgressPolling() {
-      if (this.downloadProgressInterval) {
-        clearInterval(this.downloadProgressInterval);
-        this.downloadProgressInterval = null;
-      }
-    },
   }));
   }));
 });
 });