Pranav Veldurthi 5 months ago
parent
commit
686e139508

+ 2 - 0
.gitattributes

@@ -0,0 +1,2 @@
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text

+ 2 - 2
README.md

@@ -182,10 +182,10 @@ curl http://localhost:52415/v1/chat/completions \
 #### Device 1 (MacOS):
 
 ```sh
-exo --inference-engine tinygrad
+exo
 ```
 
-Here we explicitly tell exo to use the **tinygrad** inference engine.
+Note: We don't need to explicitly tell exo to use the **tinygrad** inference engine. **MLX** and **tinygrad** are interoperable!
 
 #### Device 2 (Linux):
 ```sh

BIN
docs/exo-logo-transparent-black-text.png


BIN
docs/exo-logo-transparent.png


BIN
docs/exo-rounded.png


BIN
docs/exo-screenshot.png


BIN
docs/ring-topology.png


+ 64 - 5
exo/api/chatgpt_api.py

@@ -8,21 +8,21 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
-import os
 import signal
-import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from exo.apputil import create_animation_mp4
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
 import mlx.core as mx
+import tempfile
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -181,6 +181,8 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
 
       
     if "__compiled__" not in globals():
@@ -191,7 +193,7 @@ class ChatGPTAPI:
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
-  
+
   async def handle_quit(self, request):
     if DEBUG>=1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
@@ -224,11 +226,11 @@ class ChatGPTAPI:
   async def handle_model_support(self, request):
     return web.json_response({
       "model pool": {
-        model_name: pretty_name.get(model_name, model_name) 
+        model_name: pretty_name.get(model_name, model_name)
         for model_name in get_supported_models(self.node.topology_inference_engines_pool)
       }
     })
-  
+
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
@@ -457,6 +459,63 @@ class ChatGPTAPI:
         if DEBUG >= 2: traceback.print_exc()
         return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
   
+  async def handle_create_animation(self, request):
+    try:
+      data = await request.json()
+      replacement_image_path = data.get("replacement_image_path")
+      device_name = data.get("device_name", "Local Device")
+      prompt_text = data.get("prompt", "")
+
+      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")
+
+      if not replacement_image_path:
+        return web.json_response({"error": "replacement_image_path is required"}, status=400)
+
+      # Create temp directory if it doesn't exist
+      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
+      tmp_dir.mkdir(parents=True, exist_ok=True)
+
+      # Generate unique output filename in temp directory
+      output_filename = f"animation_{uuid.uuid4()}.mp4"
+      output_path = str(tmp_dir/output_filename)
+
+      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
+
+      # Create the animation
+      create_animation_mp4(
+        replacement_image_path,
+        output_path,
+        device_name,
+        prompt_text
+      )
+
+      return web.json_response({
+        "status": "success",
+        "output_path": output_path
+      })
+
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
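
For reference, the two new endpoints can be exercised as follows. This is a minimal sketch and not code from the commit: the port and JSON fields mirror the handlers above, while the model name, image path, device name and prompt are placeholder values.

```python
# Sketch only: calling the new /download and /create_animation endpoints with aiohttp.
# Port and field names follow the handlers above; all concrete values are placeholders.
import asyncio
import aiohttp

async def main():
  async with aiohttp.ClientSession() as session:
    # Kick off a model download in the background on the node.
    async with session.post("http://localhost:52415/download",
                            json={"model": "llama-3.2-1b"}) as resp:
      print(await resp.json())

    # Render the animation; the response carries the temp-file path of the mp4.
    payload = {
      "replacement_image_path": "/path/to/generated_image.png",
      "device_name": "MacBook Pro",
      "prompt": "A boat in the ocean",
    }
    async with session.post("http://localhost:52415/create_animation",
                            json=payload) as resp:
      print(await resp.json())

asyncio.run(main())
```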

+ 1 - 0
exo/apputil/__init__.py

@@ -0,0 +1 @@
+from exo.apputil.anim import create_animation_mp4

+ 161 - 0
exo/apputil/anim.py

@@ -0,0 +1,161 @@
+from PIL import Image, ImageDraw, ImageFont, ImageFilter
+import os
+import numpy as np
+import cv2
+
+def draw_rounded_rectangle(draw, coords, radius, fill):
+  left, top, right, bottom = coords
+  diameter = radius * 2
+  draw.rectangle([left + radius, top, right - radius, bottom], fill=fill)
+  draw.rectangle([left, top + radius, right, bottom - radius], fill=fill)
+  draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill)
+  draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill)
+  draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill)
+  draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill)
+
+def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_width = rect_right - rect_left
+  rect_height = rect_bottom - rect_top
+  text_x = rect_left + (rect_width - text_width) // 2
+  text_y = rect_top + (rect_height - text_height) // 2
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_height = rect_bottom - rect_top
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = rect_left + padding_left
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  _, rect_top, rect_right, rect_bottom = base_coords
+  rect_height = rect_bottom - rect_top
+  new_rect_left = rect_right - (text_width + (padding * 2))
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = new_rect_left + padding
+  draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+  return new_rect_left
+
+def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)):
+  left, top, right, bottom = coords
+  total_width = right - left
+  draw.rectangle(coords, fill=bg_color)
+  progress_width = int(total_width * progress)
+  if progress_width > 0:
+    draw.rectangle((left, top, left + progress_width, bottom), fill=color)
+
+def crop_image(image, top_crop=70):
+  width, height = image.size
+  return image.crop((0, top_crop, width, height))
+
+def create_animation_mp4(
+  replacement_image_path,
+  output_path,
+  device_name,
+  prompt_text,
+  fps=30,
+  target_size=(512, 512),
+  target_position=(139, 755),
+  progress_coords=(139, 1285, 655, 1295),
+  device_coords=(1240, 370, 1640, 416),
+  prompt_coords=(332, 1702, 2662, 1745)
+):
+  frames = []
+  try:
+    font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20)
+    promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24)
+  except:
+    font = ImageFont.load_default()
+    promptfont = ImageFont.load_default()
+
+  # Process first frame
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
+  draw = ImageDraw.Draw(base_img)
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
+
+  # Process second frame with typing animation
+  base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
+  for i in range(len(prompt_text) + 1):
+    current_frame = base_img2.copy()
+    draw = ImageDraw.Draw(current_frame)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    if i > 0:  # Only draw if we have at least one character
+      draw_left_aligned_text_rounded(draw, prompt_text[:i], promptfont, prompt_coords)
+    frames.extend([crop_image(current_frame)] * 2)  # 2 frames per character for smooth typing
+  
+  # Hold the complete prompt for a moment
+  frames.extend([frames[-1]] * 30)  # Hold for 1 second
+
+  # Create blur sequence
+  replacement_img = Image.open(replacement_image_path)
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+  blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
+
+  for i, blur_amount in enumerate(blur_steps):
+    new_frame = base_img.copy()
+    draw = ImageDraw.Draw(new_frame)
+
+    replacement_copy = replacement_img.copy()
+    replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+    if blur_amount > 0:
+      replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount))
+
+    mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+    new_frame.paste(replacement_copy, target_position, mask)
+
+    draw_progress_bar(draw, (i + 1) / 9, progress_coords)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+    frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
+
+  # Create and add final frame (image4)
+  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+  draw = ImageDraw.Draw(final_base)
+
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+  replacement_copy = replacement_img.copy()
+  replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+  mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+  final_base.paste(replacement_copy, target_position, mask)
+
+  frames.extend([crop_image(final_base)] * 30)  # 1 second at 30fps
+
+  # Convert frames to video using H.264 codec
+  if frames:
+    first_frame = np.array(frames[0])
+    height, width = first_frame.shape[:2]
+    fourcc = cv2.VideoWriter_fourcc(*'avc1')
+    out = cv2.VideoWriter(
+      output_path,
+      fourcc,
+      fps,
+      (width, height),
+      isColor=True
+    )
+
+    if not out.isOpened():
+      print("Error: VideoWriter failed to open")
+      return
+
+    for frame in frames:
+      frame_array = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
+      out.write(frame_array)
+    
+    out.release()
+    print(f"Video saved successfully to {output_path}")
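
A minimal usage sketch of the new helper, outside the API handler (the input image and output path are placeholders, not files shipped with this commit; the remaining parameters keep the defaults from the signature above):

```python
# Sketch only: driving create_animation_mp4 directly. Paths below are placeholders.
from exo.apputil import create_animation_mp4

create_animation_mp4(
  "/path/to/replacement_image.png",  # image composited over the base frames
  "/tmp/exo_demo_animation.mp4",     # H.264 mp4 written with OpenCV
  "MacBook Pro",                     # device_name drawn onto each frame
  "A boat in the ocean",             # prompt_text typed out character by character
)
```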

+ 3 - 0
exo/apputil/baseimages/image1.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:361fdadd67c277d45cd18b0bfc8c5ceea5fd89f2d65aef157fd915ce9cbb8599
+size 814460

+ 3 - 0
exo/apputil/baseimages/image2.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0e3891bc6b4f4dfa7444af53fcaa4b3ba06b0549546202be3243f08a0e6bd7e
+size 814235

+ 3 - 0
exo/apputil/baseimages/image3.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2dc5b3378aef397d60fd1252da8a1c578ad97e202a859590ffa416b49551d19
+size 146633

+ 3 - 0
exo/apputil/baseimages/image4.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbc6883e2a3c5233ec7b844c98646922bdc4f5e42e1f424857eaff56f785dbcd
+size 668550

+ 1 - 1
exo/inference/dummy_inference_engine.py

@@ -18,7 +18,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     return np.array(self.tokenizer.encode(prompt))
   
-  async def sample(self, x: np.ndarray) -> np.ndarray:
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
     return x
 

+ 3 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -1,15 +1,16 @@
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.sample_utils import top_p_sampling
 from ..inference_engine import InferenceEngine
 from .stateful_model import StatefulModel
-from .sharded_utils import load_shard, get_image_from_str
+from .sharded_utils import load_shard
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
-from functools import partial
+
 def sample_logits(
   logits: mx.array,
   temp: float = 0.0,
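
The sampling changes in this commit thread `temp` and `top_p` through the engines (the MLX engine now pulls in `top_p_sampling` from `mlx_lm`). As a rough illustration of what those parameters mean, here is a NumPy sketch, not the engine code: `temp=0.0` reduces to greedy argmax, otherwise logits are softened by `1/temp` and sampling is restricted to the top-p nucleus.

```python
# Sketch only (NumPy, not the engine code): temperature and top-p (nucleus) sampling.
import numpy as np

def sample_token(logits: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> int:
  if temp == 0.0:
    return int(np.argmax(logits))          # greedy decoding
  scaled = logits / temp
  probs = np.exp(scaled - np.max(scaled))  # softmax with temperature
  probs /= probs.sum()
  if top_p < 1.0:
    order = np.argsort(probs)[::-1]        # tokens by descending probability
    cumulative = np.cumsum(probs[order])
    keep = order[cumulative <= top_p]
    if len(keep) == 0:
      keep = order[:1]                     # always keep the most likely token
    mask = np.zeros_like(probs)
    mask[keep] = probs[keep]
    probs = mask / mask.sum()
  return int(np.random.choice(len(probs), p=probs))
```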

+ 6 - 1
exo/inference/tinygrad/tinygrad_helpers.py

@@ -7,6 +7,7 @@ from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 from exo.download.hf.hf_helpers import get_allow_patterns
 from fnmatch import fnmatch
+import re
 
 
 # **** helper functions ****
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
     if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}
   elif fn.endswith(".safetensors"):
-    return safe_load(fn)
+    weight_map = safe_load(fn)
+    for k in list(weight_map):
+      if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+          del weight_map[k]
+    return weight_map
   else:
     return torch_load(fn)
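
The effect of the new safetensors branch: weight keys whose layer index falls outside the shard's [start_layer, end_layer] range are dropped before the tensors are returned. A small sketch with invented key names (not taken from any real checkpoint):

```python
# Sketch only: the same layer filter applied to invented weight keys.
import re

class FakeShard:
  start_layer, end_layer = 0, 1

shard = FakeShard()
weight_map = {
  "model.layers.0.self_attn.q_proj.weight": "...",
  "model.layers.1.mlp.up_proj.weight": "...",
  "model.layers.2.self_attn.q_proj.weight": "...",  # outside the shard, gets dropped
  "model.embed_tokens.weight": "...",               # no layer index, always kept
}
for k in list(weight_map):
  if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
    del weight_map[k]
print(sorted(weight_map))  # layer 2 removed; embed_tokens and layers 0-1 remain
```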

+ 11 - 3
exo/main.py

@@ -55,8 +55,10 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -88,6 +90,9 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
+# Convert node-id-filter to list if provided
+allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
+
 if args.discovery_module == "udp":
   discovery = UDPDiscovery(
     args.node_id,
@@ -95,7 +100,8 @@ if args.discovery_module == "udp":
     args.listen_port,
     args.broadcast_port,
     lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
-    discovery_timeout=args.discovery_timeout
+    discovery_timeout=args.discovery_timeout,
+    allowed_node_ids=allowed_node_ids
   )
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
@@ -104,7 +110,8 @@ elif args.discovery_module == "tailscale":
     lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
-    tailnet=args.tailnet_name
+    tailnet=args.tailnet_name,
+    allowed_node_ids=allowed_node_ids
   )
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
@@ -119,7 +126,8 @@ node = StandardNode(
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader
+  shard_downloader=shard_downloader,
+  default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

+ 6 - 0
exo/networking/tailscale/tailscale_discovery.py

@@ -21,6 +21,7 @@ class TailscaleDiscovery(Discovery):
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     tailscale_api_key: str = None,
     tailnet: str = None,
+    allowed_node_ids: List[str] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -34,6 +35,7 @@ class TailscaleDiscovery(Discovery):
     self.cleanup_task = None
     self.tailscale_api_key = tailscale_api_key
     self.tailnet = tailnet
+    self.allowed_node_ids = allowed_node_ids
     self._device_id = None
     self.update_task = None
 
@@ -84,6 +86,10 @@ class TailscaleDiscovery(Discovery):
             if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
             continue
 
+          if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+            if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+            continue
+
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
             new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
             if not await new_peer_handle.health_check():

+ 8 - 0
exo/networking/udp/udp_discovery.py

@@ -45,6 +45,7 @@ class UDPDiscovery(Discovery):
     broadcast_interval: int = 1,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    allowed_node_ids: List[str] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -54,6 +55,7 @@ class UDPDiscovery(Discovery):
     self.broadcast_interval = broadcast_interval
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
+    self.allowed_node_ids = allowed_node_ids
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
@@ -133,6 +135,12 @@ class UDPDiscovery(Discovery):
 
     if message["type"] == "discovery" and message["node_id"] != self.node_id:
       peer_id = message["node_id"]
+      
+      # Skip if peer_id is not in allowed list
+      if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+        return
+
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]
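
Together with the new `--node-id-filter` flag in exo/main.py, this check (mirrored in the Tailscale discovery above) means a node only admits peers whose IDs appear in the list. A minimal sketch with invented node IDs:

```python
# Sketch only: semantics of --node-id-filter. Node IDs below are invented.
allowed_node_ids = "macbook-1,linux-box-2".split(",")  # parsed from the CLI flag

def should_admit(peer_id: str) -> bool:
  # An empty/None filter admits everyone; otherwise only listed peers are kept.
  return not allowed_node_ids or peer_id in allowed_node_ids

assert should_admit("macbook-1")
assert not should_admit("stranger-node")
```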

+ 3 - 1
exo/orchestration/standard_node.py

@@ -27,6 +27,7 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
+    default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
     shard_downloader: Optional[HFShardDownloader] = None,
   ):
@@ -43,6 +44,7 @@ class StandardNode(Node):
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
+    self.default_sample_temperature = default_sample_temperature
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
@@ -114,7 +116,7 @@ class StandardNode(Node):
         self.buffered_token_output[request_id] = ([], False)
       is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if shard.is_last_layer() and not is_finished:
-        token = await self.inference_engine.sample(result)
+        token = await self.inference_engine.sample(result, temp = self.default_sample_temperature)
         self.buffered_token_output[request_id][0].append(token.item())
         intermediate_result = self.buffered_token_output[request_id][0]
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")

+ 138 - 56
extra/dashboard/dashboard.py

@@ -10,7 +10,7 @@ from pathlib import Path
 from plotly.subplots import make_subplots
 import plotly.graph_objects as go
 import time
-import simpleaudio as sa
+import pygame.mixer
 from datetime import datetime
 
 class AsyncCircleCIClient:
@@ -39,15 +39,7 @@ class AsyncCircleCIClient:
     ):
         """
         Get recent pipelines for a project with pagination support
-
-        Args:
-            session: aiohttp client session
-            org_slug: Organization slug
-            page_token: Token for pagination
-            limit: Maximum number of pipelines to return
-            branch: Specific branch to fetch pipelines from
         """
-
         params = {
             "branch": branch,
             "page-token": page_token
@@ -62,13 +54,17 @@ class AsyncCircleCIClient:
 
         next_page_token = data.get("next_page_token")
 
+        # If we have a limit, check if we need more pages
+        if limit and len(pipelines) >= limit:
+            return pipelines
+
         # If there are more pages and we haven't hit the limit, recursively get them
-        if next_page_token and (limit is None or len(pipelines) < limit):
+        if next_page_token:
             next_pipelines = await self.get_recent_pipelines(
                 session,
                 org_slug,
                 page_token=next_page_token,
-                limit=limit,
+                limit=limit - len(pipelines) if limit else None,  # Adjust limit for next page
                 branch=branch
             )
             pipelines.extend(next_pipelines)
@@ -108,20 +104,30 @@ class PackageSizeTracker:
         self.client = AsyncCircleCIClient(token, project_slug)
         self.logger = logging.getLogger("PackageSizeTracker")
         self.last_data_hash = None
-        self.notification_sound_path = Path(__file__).parent / "notification.wav"
         self.debug = debug
 
-        # Sound file paths - replace these with your actual sound files
+        # Initialize pygame mixer
+        pygame.mixer.init()
+
+        # Sound file paths - can use MP3 files with pygame
         sounds_dir = Path(__file__).parent / "sounds"
         self.sounds = {
-            'lines_up': sounds_dir / "lines_increased.wav",
-            'lines_down': sounds_dir / "lines_decreased.wav",
-            'tokens_up': sounds_dir / "tokens_increased.wav",
-            'tokens_down': sounds_dir / "tokens_decreased.wav",
-            'size_up': sounds_dir / "size_increased.wav",
-            'size_down': sounds_dir / "size_decreased.wav"
+            'lines_up': sounds_dir / "gta5_wasted.mp3",
+            'lines_down': sounds_dir / "pokemon_evolve.mp3",
+            'tokens_up': sounds_dir / "pokemon_evolve.mp3",
+            'tokens_down': sounds_dir / "gta5_wasted.mp3",
+            'size_up': sounds_dir / "gta5_wasted.mp3",
+            'size_down': sounds_dir / "pokemon_evolve.mp3"
         }
 
+    def test_sound_effects(self):
+        """Test all sound effects with a small delay between each"""
+        self.logger.info("Testing sound effects...")
+        for sound_key in self.sounds:
+            self.logger.info(f"Playing {sound_key}")
+            self._play_sound(sound_key)
+            time.sleep(1)  # Wait 1 second between sounds
+
     def setup_logging(self, debug: bool):
         level = logging.DEBUG if debug else logging.INFO
         logging.basicConfig(
@@ -274,16 +280,57 @@ class PackageSizeTracker:
             self.logger.error(f"Error processing pipeline {pipeline['id']}: {str(e)}")
             return None
 
+    async def process_pipeline_batch(
+        self,
+        session: aiohttp.ClientSession,
+        pipelines: List[Dict],
+        batch_size: int = 5
+    ) -> List[Dict]:
+        """
+        Process a batch of pipelines with rate limiting.
+
+        Args:
+            session: aiohttp client session
+            pipelines: List of pipelines to process
+            batch_size: Number of pipelines to process in parallel
+
+        Returns:
+            List of processed pipeline data points
+        """
+        data_points = []
+
+        for i in range(0, len(pipelines), batch_size):
+            batch = pipelines[i:i + batch_size]
+
+            # Process batch in parallel
+            tasks = [self.process_pipeline(session, pipeline) for pipeline in batch]
+            batch_results = await asyncio.gather(*tasks)
+
+            # Filter out None results
+            batch_data = [r for r in batch_results if r is not None]
+            data_points.extend(batch_data)
+
+            # Add delay between batches if there are more to process
+            if i + batch_size < len(pipelines):
+                await asyncio.sleep(1)  # 1 second delay between batches
+
+        return data_points
+
     async def collect_data(self) -> List[Dict]:
         self.logger.info("Starting data collection...")
         async with aiohttp.ClientSession(headers=self.client.headers) as session:
-            # Get pipelines from both main and circleci branches
+            # Get pipelines from main branch
             main_pipelines = await self.client.get_recent_pipelines(
                 session,
                 org_slug=self.client.project_slug,
                 limit=20,
                 branch="main"
             )
+
+            # Add delay between branch requests
+            await asyncio.sleep(2)
+
+            # Get pipelines from circleci branch
             circleci_pipelines = await self.client.get_recent_pipelines(
                 session,
                 org_slug=self.client.project_slug,
@@ -291,18 +338,27 @@ class PackageSizeTracker:
                 branch="circleci"
             )
 
+            # Combine pipelines and sort by created_at date
             pipelines = main_pipelines + circleci_pipelines
-            # Sort pipelines by created_at date
-            pipelines.sort(key=lambda x: x.get("created_at", x.get("updated_at")), reverse=True)
+            pipelines.sort(
+                key=lambda x: datetime.fromisoformat(
+                    x.get("created_at", x.get("updated_at")).replace('Z', '+00:00')
+                ),
+                reverse=True  # Most recent first
+            )
 
             self.logger.info(f"Found {len(pipelines)} recent pipelines")
 
-            # Process all pipelines in parallel
-            tasks = [self.process_pipeline(session, pipeline) for pipeline in pipelines]
-            results = await asyncio.gather(*tasks)
+            # Process pipelines in batches
+            data_points = await self.process_pipeline_batch(session, pipelines)
 
-            # Filter out None results
-            data_points = [r for r in results if r is not None]
+            # Sort by timestamp
+            data_points.sort(
+                key=lambda x: datetime.fromisoformat(
+                    x.get("timestamp").replace('Z', '+00:00')
+                ),
+                reverse=True  # Most recent first
+            )
 
         return data_points
 
@@ -931,60 +987,86 @@ class PackageSizeTracker:
         ])))
 
     def _play_sound(self, sound_key: str):
-        """Play a specific notification sound"""
+        """Play a specific notification sound using pygame"""
         try:
             sound_path = self.sounds.get(sound_key)
             if sound_path and sound_path.exists():
-                wave_obj = sa.WaveObject.from_wave_file(str(sound_path))
-                wave_obj.play()
+                sound = pygame.mixer.Sound(str(sound_path))
+                sound.play()
+                # Wait for the sound to finish playing
+                pygame.time.wait(int(sound.get_length() * 1000))
             else:
                 self.logger.warning(f"Sound file not found: {sound_key} at {sound_path}")
         except Exception as e:
             self.logger.error(f"Failed to play sound {sound_key}: {e}")
 
     def _check_metrics_changes(self, current_data: List[Dict], previous_data: List[Dict]):
-        """Check for specific metric changes and play appropriate sounds"""
-        if not previous_data:
-            return
+        # Sort data by timestamp in descending order (most recent first)
+        def sort_by_timestamp(data):
+            return sorted(
+                data,
+                key=lambda x: x.get('timestamp', ''),
+                reverse=True  # Most recent first
+            )
+
+        current_data = sort_by_timestamp(current_data)
+        previous_data = sort_by_timestamp(previous_data)
 
-        # Get latest data points
-        current = current_data[-1]
-        previous = previous_data[-1]
+        # Helper to find latest entry with a specific metric
+        def find_latest_with_metric(data: List[Dict], metric: str) -> Optional[Dict]:
+            return next((d for d in data if metric in d), None)
 
         # Check line count changes
-        if 'total_lines' in current and 'total_lines' in previous:
-            diff = current['total_lines'] - previous['total_lines']
+        current_lines = find_latest_with_metric(current_data, 'total_lines')
+        previous_lines = find_latest_with_metric(previous_data, 'total_lines')
+
+        if current_lines and previous_lines:
+            diff = current_lines['total_lines'] - previous_lines['total_lines']
+            self.logger.debug(f"Lines of code diff: {diff}")
             if diff > 0:
                 self.logger.info(f"Lines of code increased by {diff:,}")
                 self._play_sound('lines_up')
             elif diff < 0:
                 self.logger.info(f"Lines of code decreased by {abs(diff):,}")
                 self._play_sound('lines_down')
+        else:
+            self.logger.debug("No lines of code data found")
 
         # Check tokens per second changes
-        if 'tokens_per_second' in current and 'tokens_per_second' in previous:
-            diff = current['tokens_per_second'] - previous['tokens_per_second']
+        current_tokens = find_latest_with_metric(current_data, 'tokens_per_second')
+        previous_tokens = find_latest_with_metric(previous_data, 'tokens_per_second')
+
+        if current_tokens and previous_tokens:
+            diff = current_tokens['tokens_per_second'] - previous_tokens['tokens_per_second']
+            self.logger.debug(f"Tokens per second diff: {diff}")
             if diff > 0:
                 self.logger.info(f"Tokens per second increased by {diff:.2f}")
                 self._play_sound('tokens_up')
             elif diff < 0:
                 self.logger.info(f"Tokens per second decreased by {abs(diff):.2f}")
                 self._play_sound('tokens_down')
+        else:
+            self.logger.debug("No tokens per second data found")
 
         # Check package size changes
-        if 'total_size_mb' in current and 'total_size_mb' in previous:
-            diff = current['total_size_mb'] - previous['total_size_mb']
+        current_size = find_latest_with_metric(current_data, 'total_size_mb')
+        previous_size = find_latest_with_metric(previous_data, 'total_size_mb')
+
+        if current_size and previous_size:
+            diff = current_size['total_size_mb'] - previous_size['total_size_mb']
+            self.logger.debug(f"Package size diff: {diff:.2f}MB")
             if diff > 0:
                 self.logger.info(f"Package size increased by {diff:.2f}MB")
                 self._play_sound('size_up')
             elif diff < 0:
                 self.logger.info(f"Package size decreased by {abs(diff):.2f}MB")
                 self._play_sound('size_down')
+        else:
+            self.logger.debug("No package size data found")
 
-    async def run_dashboard(self, update_interval: int = 30):
+    async def run_dashboard(self, update_interval: int = 10):
         """Run the dashboard with periodic updates"""
         try:
-            # Force convert update_interval to float and log its type
             update_interval = float(update_interval)
             self.logger.debug(f"Update interval type: {type(update_interval)}, value: {update_interval}")
         except ValueError as e:
@@ -997,7 +1079,6 @@ class PackageSizeTracker:
         while True:
             try:
                 start_time = time.time()
-                self.logger.debug(f"Start time type: {type(start_time)}, value: {start_time}")
 
                 # Collect new data
                 current_data = await self.collect_data()
@@ -1013,25 +1094,26 @@ class PackageSizeTracker:
                         f"Dashboard updated at {datetime.now().strftime('%H:%M:%S')}"
                     )
 
-                    # Check for metric changes and play appropriate sounds
-                    self._check_metrics_changes(current_data, previous_data)
+                    print("Curr:", len(current_data))
+                    print("Prev:", len(previous_data) if previous_data else "None")
+                    if previous_data:
+                        # Check for metric changes and play appropriate sounds
+                        self.logger.debug(f"Checking metrics changes between {len(current_data)} current and {len(previous_data)} previous data points")
+                        self._check_metrics_changes(current_data, previous_data)
 
                 # Update previous data
-                previous_data = current_data
+                previous_data = current_data.copy()  # Make a copy to prevent reference issues
 
-                # Calculate sleep time with explicit type conversion and logging
+                # Calculate sleep time
                 elapsed = float(time.time() - start_time)
-                self.logger.debug(f"Elapsed time type: {type(elapsed)}, value: {elapsed}")
-                sleep_time = max(0.0, float(update_interval) - elapsed)
-                self.logger.debug(f"Sleep time type: {type(sleep_time)}, value: {sleep_time}")
-
+                sleep_time = max(0.0, update_interval - elapsed)
                 await asyncio.sleep(sleep_time)
 
             except Exception as e:
                 self.logger.error(f"Error in dashboard update loop: {e}", exc_info=True)
                 if self.debug:
                     raise
-                await asyncio.sleep(float(update_interval))
+                await asyncio.sleep(update_interval)
 
 async def main():
     token = os.getenv("CIRCLECI_TOKEN")
@@ -1040,11 +1122,11 @@ async def main():
 
     try:
         # Get update interval from environment or use default
-        update_interval = float(os.getenv("UPDATE_INTERVAL", "30"))
+        update_interval = float(os.getenv("UPDATE_INTERVAL", "10"))
         print(f"Update interval type: {type(update_interval)}, value: {update_interval}")  # Debug print
     except ValueError as e:
         print(f"Error converting UPDATE_INTERVAL to float: {os.getenv('UPDATE_INTERVAL')}")
-        update_interval = 30.0
+        update_interval = 10.0
 
     if not token or not project_slug:
         print("Error: Please set CIRCLECI_TOKEN and CIRCLECI_PROJECT_SLUG environment variables")

+ 1 - 1
extra/dashboard/requirements.txt

@@ -2,4 +2,4 @@ plotly
 pandas
 requests
 aiohttp
-simpleaudio
+pygame

+ 3 - 0
extra/dashboard/sounds/gta5_wasted.mp3

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb3fb66dd02827fbff86ef1ce3bc6438371c823aed7d4c3803ed522f008e4947
+size 206399

+ 3 - 0
extra/dashboard/sounds/pokemon_evolve.mp3

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d99cc9bdab4a4639d50f439b424547000e7c79f195b5b121734ad4ead435911c
+size 633345

+ 7 - 1
install.sh

@@ -1,5 +1,11 @@
 #!/bin/bash
 
-python3 -m venv .venv
+if command -v python3.12 &>/dev/null; then
+    echo "Python 3.12 is installed, proceeding with python3.12..."
+    python3.12 -m venv .venv
+else
+    echo "The recommended version of Python to run exo with is Python 3.12, but $(python3 --version) is installed. Proceeding with $(python3 --version)"
+    python3 -m venv .venv
+fi
 source .venv/bin/activate
 pip install -e .

+ 1 - 0
setup.py

@@ -15,6 +15,7 @@ install_requires = [
   "numpy==2.0.0",
   "nuitka==2.5.1",
   "nvidia-ml-py==12.560.30",
+  "opencv-python==4.10.0.84",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
   "protobuf==5.28.1",