1
0
Эх сурвалжийг харах

adding amount that has been downloaded if model is not fully downloaded

cadenmackenzie 8 сар өмнө
parent
commit
fb3baf5037

+ 5 - 1
exo/api/chatgpt_api.py

@@ -244,13 +244,17 @@ class ChatGPTAPI:
               
               # Get overall percentage from status
               download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False
               if DEBUG >= 2 and download_percentage is not None:
                   print(f"Overall download percentage for {model_name}: {download_percentage}")
               
               model_pool[model_name] = {
                   "name": pretty,
                   "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                  "download_percentage": download_percentage
+                  "download_percentage": download_percentage,
+                  "total_size": total_size,
+                  "total_downloaded": total_downloaded
               }
       
       return web.json_response({"model pool": model_pool})

+ 8 - 2
exo/download/hf/hf_shard_download.py

@@ -1,7 +1,7 @@
 import asyncio
 import traceback
 from pathlib import Path
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Tuple, Optional, Union
 from exo.inference.shard import Shard
 from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
@@ -90,7 +90,7 @@ class HFShardDownloader(ShardDownloader):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return self._on_progress
 
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
     if not self.current_shard or not self.current_repo_id:
       if DEBUG >= 2:
         print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
@@ -144,6 +144,12 @@ class HFShardDownloader(ShardDownloader):
           status["overall"] = (downloaded_bytes / total_bytes) * 100
         else:
           status["overall"] = 0
+          
+        # Add total size in bytes
+        status["total_size"] = total_bytes
+        if status["overall"] != 100:
+          status["total_downloaded"] = downloaded_bytes
+        
 
         if DEBUG >= 2:
           print(f"Download calculation for {self.current_repo_id}:")

+ 12 - 0
exo/tinychat/index.css

@@ -545,4 +545,16 @@ main {
 .back-button:hover {
   transform: translateX(-5px);
   background-color: var(--secondary-color-transparent);
+}
+
+.model-info {
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.model-size {
+  font-size: 0.8em;
+  color: var(--secondary-color-transparent);
+  opacity: 0.8;
 }

+ 11 - 3
exo/tinychat/index.html

@@ -32,9 +32,17 @@
              :class="{ 'selected': cstate.selectedModel === key }"
              @click="cstate.selectedModel = key">
             <div class="model-name" x-text="model.name"></div>
-            <div class="model-progress">
-                <template x-if="model.download_percentage != null">
-                    <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+            <div class="model-info">
+                <div class="model-progress">
+                    <template x-if="model.download_percentage != null">
+                        <span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
+                    </template>
+                </div>
+                <template x-if="model.total_size">
+                    <div class="model-size" x-text="model.total_downloaded ? 
+                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` : 
+                        formatBytes(model.total_size)">
+                    </div>
                 </template>
             </div>
         </div>

+ 3 - 1
exo/tinychat/index.js

@@ -105,9 +105,11 @@ document.addEventListener("alpine:init", () => {
             this.models[key].name = value.name;
             this.models[key].downloaded = value.downloaded;
             this.models[key].download_percentage = value.download_percentage;
+            this.models[key].total_size = value.total_size;
+            this.models[key].total_downloaded = value.total_downloaded;
           }
         });
-        
+                
       } catch (error) {
         console.error("Error populating model selector:", error);
         this.setError(error);