Backend endpoint now uses SSE to send each model as it's loaded. Also shows a loading indicator until the first model shows up.

cadenmackenzie, 8 months ago
Parent commit: e16170cfcb
4 files changed with 96 additions and 71 deletions
  1. exo/api/chatgpt_api.py (+55 -47)
  2. exo/tinychat/index.css (+17 -0)
  3. exo/tinychat/index.html (+7 -0)
  4. exo/tinychat/index.js (+17 -24)

exo/api/chatgpt_api.py (+55 -47)

@@ -219,54 +219,62 @@ class ChatGPTAPI:
 
   async def handle_model_support(self, request):
     try:
-      model_pool = {}
-      
-      for model_name, pretty in pretty_name.items():
-        if model_name in model_cards:
-          model_info = model_cards[model_name]
-          
-          # Get required engines from the node's topology directly
-          required_engines = list(dict.fromkeys(
-              [engine_name for engine_list in self.node.topology_inference_engines_pool 
-               for engine_name in engine_list 
-               if engine_name is not None] + 
-              [self.inference_engine_classname]
-          ))          
-          # Check if model supports required engines
-          if all(map(lambda engine: engine in model_info["repo"], required_engines)):
-            shard = build_base_shard(model_name, self.inference_engine_classname)
-            if shard:
-                # Use HFShardDownloader to check status without initiating download
-              downloader = HFShardDownloader(quick_check=True)  # quick_check=True prevents downloads
-              downloader.current_shard = shard
-              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-              status = await downloader.get_shard_download_status()
-              if DEBUG >= 2:
-                  print(f"Download status for {model_name}: {status}")
-              
-              # Get overall percentage from status
-              download_percentage = status.get("overall") if status else None
-              total_size = status.get("total_size") if status else None
-              total_downloaded = status.get("total_downloaded") if status else False
-              if DEBUG >= 2 and download_percentage is not None:
-                  print(f"Overall download percentage for {model_name}: {download_percentage}")
-              
-              model_pool[model_name] = {
-                  "name": pretty,
-                  "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                  "download_percentage": download_percentage,
-                  "total_size": total_size,
-                  "total_downloaded": total_downloaded
-              }
-      
-      return web.json_response({"model pool": model_pool})
+        response = web.StreamResponse(
+            status=200,
+            reason='OK',
+            headers={
+                'Content-Type': 'text/event-stream',
+                'Cache-Control': 'no-cache',
+                'Connection': 'keep-alive',
+            }
+        )
+        await response.prepare(request)
+        
+        for model_name, pretty in pretty_name.items():
+            if model_name in model_cards:
+                model_info = model_cards[model_name]
+                
+                required_engines = list(dict.fromkeys(
+                    [engine_name for engine_list in self.node.topology_inference_engines_pool 
+                     for engine_name in engine_list 
+                     if engine_name is not None] + 
+                    [self.inference_engine_classname]
+                ))
+                
+                if all(map(lambda engine: engine in model_info["repo"], required_engines)):
+                    shard = build_base_shard(model_name, self.inference_engine_classname)
+                    if shard:
+                        downloader = HFShardDownloader(quick_check=True)
+                        downloader.current_shard = shard
+                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+                        status = await downloader.get_shard_download_status()
+                        
+                        download_percentage = status.get("overall") if status else None
+                        total_size = status.get("total_size") if status else None
+                        total_downloaded = status.get("total_downloaded") if status else False
+                        
+                        model_data = {
+                            model_name: {
+                                "name": pretty,
+                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                                "download_percentage": download_percentage,
+                                "total_size": total_size,
+                                "total_downloaded": total_downloaded
+                            }
+                        }
+                        
+                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+        
+        await response.write(b"data: [DONE]\n\n")
+        return response
+        
     except Exception as e:
-      print(f"Error in handle_model_support: {str(e)}")
-      traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Server error: {str(e)}"}, 
-        status=500
-      )
+        print(f"Error in handle_model_support: {str(e)}")
+        traceback.print_exc()
+        return web.json_response(
+            {"detail": f"Server error: {str(e)}"}, 
+            status=500
+        )
 
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
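
A caveat on the new error path: once response.prepare(request) has run, the 200 status line is already committed to the client, so the web.json_response(..., status=500) returned from the except block can no longer be delivered as a normal HTTP error. A minimal sketch of a guard using aiohttp's StreamResponse.prepared property; the in-band error frame is a hypothetical convention, not part of this commit:

    except Exception as e:
        print(f"Error in handle_model_support: {str(e)}")
        traceback.print_exc()
        # If the SSE stream has already started, the status is committed;
        # report the failure in-band as a final frame (hypothetical
        # convention) and end the stream.
        if response.prepared:
            await response.write(f"data: {json.dumps({'error': str(e)})}\n\n".encode())
            return response
        return web.json_response(
            {"detail": f"Server error: {str(e)}"},
            status=500
        )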

exo/tinychat/index.css (+17 -0)

@@ -583,4 +583,21 @@ main {
 
 .model-option:hover .model-delete-button {
     opacity: 1;
+}
+
+.loading-container {
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    gap: 10px;
+    padding: 20px;
+    color: var(--secondary-color-transparent);
+}
+
+.loading-container i {
+    font-size: 24px;
+}
+
+.loading-container span {
+    font-size: 14px;
 }

exo/tinychat/index.html (+7 -0)

@@ -27,6 +27,13 @@
 <main x-data="state" x-init="console.log(endpoint)">
   <div class="sidebar">
     <h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
+    
+    <!-- Loading indicator -->
+    <div class="loading-container" x-show="Object.keys(models).length === 0">
+        <i class="fas fa-spinner fa-spin"></i>
+        <span>Loading models...</span>
+    </div>
+    
     <template x-for="(model, key) in models" :key="key">
         <div class="model-option" 
              :class="{ 'selected': cstate.selectedModel === key }"

exo/tinychat/index.js (+17 -24)

@@ -88,32 +88,25 @@ document.addEventListener("alpine:init", () => {
     },
 
     async populateSelector() {
-      try {
-        const response = await fetch(`${window.location.origin}/modelpool`);
-        if (!response.ok) {
-          throw new Error(`HTTP error! status: ${response.status}`);
+      const evtSource = new EventSource(`${window.location.origin}/modelpool`);
+      
+      evtSource.onmessage = (event) => {
+        if (event.data === "[DONE]") {
+          evtSource.close();
+          return;
         }
-
-        const data = await response.json();
         
-        // Update the models state with the full model pool data
-        Object.entries(data["model pool"]).forEach(([key, value]) => {
-          if (!this.models[key]) {
-            this.models[key] = value;
-          } else {
-            // Update existing model info while preserving reactivity
-            this.models[key].name = value.name;
-            this.models[key].downloaded = value.downloaded;
-            this.models[key].download_percentage = value.download_percentage;
-            this.models[key].total_size = value.total_size;
-            this.models[key].total_downloaded = value.total_downloaded;
-          }
-        });
-                
-      } catch (error) {
-        console.error("Error populating model selector:", error);
-        this.setError(error);
-      }
+        const modelData = JSON.parse(event.data);
+        this.models = {
+          ...this.models,
+          ...modelData
+        };
+      };
+      
+      evtSource.onerror = (error) => {
+        console.error('EventSource failed:', error);
+        evtSource.close();
+      };
     },
 
     async handleImageUpload(event) {
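
For exercising the stream outside the browser, here is a minimal sketch of a client in Python with aiohttp, mirroring the EventSource logic above; the host and port (localhost:52415) are an assumption for a local exo instance, not part of this commit:

    import asyncio
    import json

    import aiohttp

    async def fetch_model_pool():
        # Accumulate per-model SSE frames into one dict, like the
        # object-spread merge in populateSelector above.
        models = {}
        async with aiohttp.ClientSession() as session:
            # NOTE: host/port are an assumption for local testing.
            async with session.get("http://localhost:52415/modelpool") as resp:
                async for raw in resp.content:
                    line = raw.decode().strip()
                    if not line.startswith("data: "):
                        continue  # skip the blank separator lines between frames
                    payload = line[len("data: "):]
                    if payload == "[DONE]":
                        break  # server signals the end of the model list
                    models.update(json.loads(payload))
        return models

    if __name__ == "__main__":
        print(asyncio.run(fetch_model_pool()))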