Bläddra i källkod

Merge pull request #653 from exo-explore/tinyfixes

Tiny fixes
Alex Cheema 3 månader sedan
förälder
incheckning
24c410c19c
5 ändrade filer med 63 tillägg och 58 borttagningar
  1. 3 3
      exo/download/new_shard_download.py
  2. 5 4
      exo/inference/tinygrad/inference.py
  3. 7 5
      exo/main.py
  4. 47 45
      exo/viz/topology_viz.py
  5. 1 1
      setup.py

+ 3 - 3
exo/download/new_shard_download.py

@@ -105,7 +105,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog
   elapsed_time = time.time() - all_start_time
   all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
   all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
-  status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
+  status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "in_progress" if any(p.status == "in_progress" for p in file_progress.values()) else "not_started"
   return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
 
 async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
@@ -147,12 +147,12 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
       downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
       speed = downloaded_this_session / (time.time() - start_time)
       eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
-      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time)
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
       on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
       if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
     for file in filtered_file_list:
       downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
-      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
+      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
 
     semaphore = asyncio.Semaphore(max_parallel_downloads)
     async def download_with_semaphore(file):

+ 5 - 4
exo/inference/tinygrad/inference.py

@@ -61,12 +61,13 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
   return model
 
+_executor = ThreadPoolExecutor(max_workers=1) # singleton so tinygrad always runs on the same thread
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
     self.states = OrderedDict()
+    self.executor = _executor
 
   def poll_state(self, x, request_id: str, max_states=2):
     if request_id not in self.states:
@@ -79,8 +80,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     return {"start_pos": state.start, "cache": state.cache}
 
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
-    logits = x[:, -1, :]
     def sample_wrapper():
+      logits = x[:, -1, :]
       return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
     return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
@@ -112,9 +113,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       state = self.poll_state(h, request_id)
       out = self.model.forward(h, **state)
       self.states[request_id].start += x.shape[1]
-      return out.realize()
+      return out.numpy()
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
-    return output_data.numpy(), inference_state
+    return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
     def step(x, y, l):

+ 7 - 5
exo/main.py

@@ -206,14 +206,16 @@ def preemptively_load_shard(request_id: str, opaque_status: str):
       traceback.print_exc()
 node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
 
-last_broadcast_time = 0
+last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-  global last_broadcast_time
+  global last_events
   current_time = time.time()
   if event.status == "not_started": return
-  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-    last_broadcast_time = current_time
-    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+  last_event = last_events.get(shard.model_id)
+  if last_event and last_event[1].status == "complete" and event.status == "complete": return
+  if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
+  last_events[shard.model_id] = (current_time, event)
+  asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 async def run_model_cli(node: Node, model_name: str, prompt: str):

+ 47 - 45
exo/viz/topology_viz.py

@@ -89,16 +89,16 @@ class TopologyViz:
     # Calculate available height for content
     panel_height = 15  # Fixed panel height
     available_lines = panel_height - 2  # Subtract 2 for panel borders
-    lines_per_entry = available_lines // len(requests) if requests else 0
+    lines_per_request = available_lines // len(requests) if requests else 0
 
     for (prompt, output) in reversed(requests):
       prompt_icon, output_icon = "💬️", "🤖"
 
-      # Calculate max lines for prompt and output
-      max_prompt_lines = max(3, lines_per_entry // 2)  # Ensure at least 3 lines for prompt
-      max_output_lines = lines_per_entry - max_prompt_lines - 1  # Remaining space minus spacing
+      # Equal space allocation for prompt and output
+      max_prompt_lines = lines_per_request // 2
+      max_output_lines = lines_per_request - max_prompt_lines - 1  # -1 for spacing
 
-      # Process prompt with more generous line allocation
+      # Process prompt
       prompt_lines = []
       for line in prompt.split('\n'):
         words = line.split()
@@ -118,53 +118,55 @@ class TopologyViz:
         if current_line:
           prompt_lines.append(' '.join(current_line))
 
-      # Show more prompt content and append ellipses to last line if needed
+      # Truncate prompt if needed
       if len(prompt_lines) > max_prompt_lines:
         prompt_lines = prompt_lines[:max_prompt_lines]
-        # Append ellipses to last line if there's room, otherwise truncate last line
-        last_line = prompt_lines[-1]
-        if len(last_line) + 4 <= max_width:  # +4 for " ..."
-          prompt_lines[-1] = last_line + " ..."
-        else:
-          prompt_lines[-1] = last_line[:max_width-4] + " ..."
+        if prompt_lines:
+          last_line = prompt_lines[-1]
+          if len(last_line) + 4 <= max_width:
+            prompt_lines[-1] = last_line + " ..."
+          else:
+            prompt_lines[-1] = last_line[:max_width-4] + " ..."
 
       prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
       prompt_text.append('\n'.join(prompt_lines), style="white")
+      content.append(prompt_text)
 
-      # Process output - same word-aware wrapping
-      output_lines = []
-      for line in output.split('\n'):
-        words = line.split()
-        current_line = []
-        current_length = 0
-
-        for word in words:
-          if current_length + len(word) + 1 <= max_width:
-            current_line.append(word)
-            current_length += len(word) + 1
-          else:
-            if current_line:
-              output_lines.append(' '.join(current_line))
-            current_line = [word]
-            current_length = len(word)
-
-        if current_line:
-          output_lines.append(' '.join(current_line))
-
-      if len(output_lines) > max_output_lines:
-        output_lines = output_lines[:max_output_lines]
-        last_line = output_lines[-1] if output_lines else None
-        if last_line:
-          if len(last_line) + 4 <= max_width:
-            output_lines[-1] = last_line + " ..."
-          else:
-            output_lines[-1] = last_line[:max_width-4] + " ..."
-
-      output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
-      output_text.append('\n'.join(output_lines), style="white")
+      # Process output with similar word wrapping
+      if output:  # Only process output if it exists
+        output_lines = []
+        for line in output.split('\n'):
+          words = line.split()
+          current_line = []
+          current_length = 0
+
+          for word in words:
+            if current_length + len(word) + 1 <= max_width:
+              current_line.append(word)
+              current_length += len(word) + 1
+            else:
+              if current_line:
+                output_lines.append(' '.join(current_line))
+              current_line = [word]
+              current_length = len(word)
+
+          if current_line:
+            output_lines.append(' '.join(current_line))
+
+        # Truncate output if needed
+        if len(output_lines) > max_output_lines:
+          output_lines = output_lines[:max_output_lines]
+          if output_lines:
+            last_line = output_lines[-1]
+            if len(last_line) + 4 <= max_width:
+              output_lines[-1] = last_line + " ..."
+            else:
+              output_lines[-1] = last_line[:max_width-4] + " ..."
+
+        output_text = Text(f"{output_icon} ", style="bold bright_magenta")
+        output_text.append('\n'.join(output_lines), style="white")
+        content.append(output_text)
 
-      content.append(prompt_text)
-      content.append(output_text)
       content.append(Text())  # Empty line between entries
 
     return Panel(

+ 1 - 1
setup.py

@@ -29,7 +29,7 @@ install_requires = [
   "transformers==4.46.3",
   "uuid==1.30",
   "uvloop==0.21.0",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
 ]
 
 extras_require = {