
Merge Latest

Pranav Veldurthi 5 months ago
commit 686e139508

+ 2 - 0
.gitattributes

@@ -0,0 +1,2 @@
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text

+ 2 - 2
README.md

@@ -182,10 +182,10 @@ curl http://localhost:52415/v1/chat/completions \
 #### Device 1 (MacOS):
 
 ```sh
-exo --inference-engine tinygrad
+exo
 ```
 
-Here we explicitly tell exo to use the **tinygrad** inference engine.
+Note: We don't need to explicitly tell exo to use the **tinygrad** inference engine. **MLX** and **tinygrad** are interoperable!
 
 #### Device 2 (Linux):
 ```sh

BIN
docs/exo-logo-transparent-black-text.png


BIN
docs/exo-logo-transparent.png


BIN
docs/exo-rounded.png


BIN
docs/exo-screenshot.png


BIN
docs/ring-topology.png


+ 64 - 5
exo/api/chatgpt_api.py

@@ -8,21 +8,21 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
-import os
 import signal
-import sys
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from exo.apputil import create_animation_mp4
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
 import mlx.core as mx
+import tempfile
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -181,6 +181,8 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
     cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
 
 
     if "__compiled__" not in globals():
@@ -191,7 +193,7 @@ class ChatGPTAPI:
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
-  
+
   async def handle_quit(self, request):
     if DEBUG>=1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
@@ -224,11 +226,11 @@ class ChatGPTAPI:
   async def handle_model_support(self, request):
     return web.json_response({
       "model pool": {
-        model_name: pretty_name.get(model_name, model_name) 
+        model_name: pretty_name.get(model_name, model_name)
         for model_name in get_supported_models(self.node.topology_inference_engines_pool)
       }
     })
-  
+
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
@@ -457,6 +459,63 @@ class ChatGPTAPI:
         if DEBUG >= 2: traceback.print_exc()
         return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
+  async def handle_create_animation(self, request):
+    try:
+      data = await request.json()
+      replacement_image_path = data.get("replacement_image_path")
+      device_name = data.get("device_name", "Local Device")
+      prompt_text = data.get("prompt", "")
+
+      if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")
+
+      if not replacement_image_path:
+        return web.json_response({"error": "replacement_image_path is required"}, status=400)
+
+      # Create temp directory if it doesn't exist
+      tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
+      tmp_dir.mkdir(parents=True, exist_ok=True)
+
+      # Generate unique output filename in temp directory
+      output_filename = f"animation_{uuid.uuid4()}.mp4"
+      output_path = str(tmp_dir/output_filename)
+
+      if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
+
+      # Create the animation
+      create_animation_mp4(
+        replacement_image_path,
+        output_path,
+        device_name,
+        prompt_text
+      )
+
+      return web.json_response({
+        "status": "success",
+        "output_path": output_path
+      })
+
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()

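For quick verification, the two new endpoints can be exercised with plain HTTP calls. This is a hedged sketch based on the handlers above: the field names (`model`, `replacement_image_path`, `device_name`, `prompt`) come straight from the handler code, while the model name and image path are illustrative placeholders and the host/port assume the default API address.

```sh
# Sketch: kick off a background model download via the new /download endpoint
curl -X POST http://localhost:52415/download \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b"}'

# Sketch: render the demo animation via the new /create_animation endpoint
curl -X POST http://localhost:52415/create_animation \
  -H "Content-Type: application/json" \
  -d '{"replacement_image_path": "/tmp/generated.png", "device_name": "MacBook Pro", "prompt": "A cat in a hat"}'
```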
+ 1 - 0
exo/apputil/__init__.py

@@ -0,0 +1 @@
+from exo.apputil.anim import create_animation_mp4

+ 161 - 0
exo/apputil/anim.py

@@ -0,0 +1,161 @@
+from PIL import Image, ImageDraw, ImageFont, ImageFilter
+import os
+import numpy as np
+import cv2
+
+def draw_rounded_rectangle(draw, coords, radius, fill):
+  left, top, right, bottom = coords
+  diameter = radius * 2
+  draw.rectangle([left + radius, top, right - radius, bottom], fill=fill)
+  draw.rectangle([left, top + radius, right, bottom - radius], fill=fill)
+  draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill)
+  draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill)
+  draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill)
+  draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill)
+
+def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_width = rect_right - rect_left
+  rect_height = rect_bottom - rect_top
+  text_x = rect_left + (rect_width - text_width) // 2
+  text_y = rect_top + (rect_height - text_height) // 2
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_height = bbox[3] - bbox[1]
+  rect_left, rect_top, rect_right, rect_bottom = rect_coords
+  rect_height = rect_bottom - rect_top
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = rect_left + padding_left
+  draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+
+def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
+  bbox = font.getbbox(text)
+  text_width = bbox[2] - bbox[0]
+  text_height = bbox[3] - bbox[1]
+  _, rect_top, rect_right, rect_bottom = base_coords
+  rect_height = rect_bottom - rect_top
+  new_rect_left = rect_right - (text_width + (padding * 2))
+  text_y = rect_top + (rect_height - text_height) // 2
+  text_x = new_rect_left + padding
+  draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color)
+  draw.text((text_x, text_y), text, fill=text_color, font=font)
+  return new_rect_left
+
+def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)):
+  left, top, right, bottom = coords
+  total_width = right - left
+  draw.rectangle(coords, fill=bg_color)
+  progress_width = int(total_width * progress)
+  if progress_width > 0:
+    draw.rectangle((left, top, left + progress_width, bottom), fill=color)
+
+def crop_image(image, top_crop=70):
+  width, height = image.size
+  return image.crop((0, top_crop, width, height))
+
+def create_animation_mp4(
+  replacement_image_path,
+  output_path,
+  device_name,
+  prompt_text,
+  fps=30,
+  target_size=(512, 512),
+  target_position=(139, 755),
+  progress_coords=(139, 1285, 655, 1295),
+  device_coords=(1240, 370, 1640, 416),
+  prompt_coords=(332, 1702, 2662, 1745)
+):
+  frames = []
+  try:
+    font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20)
+    promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24)
+  except:
+    font = ImageFont.load_default()
+    promptfont = ImageFont.load_default()
+
+  # Process first frame
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
+  draw = ImageDraw.Draw(base_img)
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  frames.extend([crop_image(base_img)] * 30)  # 1 second at 30fps
+
+  # Process second frame with typing animation
+  base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
+  for i in range(len(prompt_text) + 1):
+    current_frame = base_img2.copy()
+    draw = ImageDraw.Draw(current_frame)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    if i > 0:  # Only draw if we have at least one character
+      draw_left_aligned_text_rounded(draw, prompt_text[:i], promptfont, prompt_coords)
+    frames.extend([crop_image(current_frame)] * 2)  # 2 frames per character for smooth typing
+  
+  # Hold the complete prompt for a moment
+  frames.extend([frames[-1]] * 30)  # Hold for 1 second
+
+  # Create blur sequence
+  replacement_img = Image.open(replacement_image_path)
+  base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+  blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
+
+  for i, blur_amount in enumerate(blur_steps):
+    new_frame = base_img.copy()
+    draw = ImageDraw.Draw(new_frame)
+
+    replacement_copy = replacement_img.copy()
+    replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+    if blur_amount > 0:
+      replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount))
+
+    mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+    new_frame.paste(replacement_copy, target_position, mask)
+
+    draw_progress_bar(draw, (i + 1) / 9, progress_coords)
+    draw_centered_text_rounded(draw, device_name, font, device_coords)
+    draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+    frames.extend([crop_image(new_frame)] * 15)  # 0.5 seconds at 30fps
+
+  # Create and add final frame (image4)
+  final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+  draw = ImageDraw.Draw(final_base)
+
+  draw_centered_text_rounded(draw, device_name, font, device_coords)
+  draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
+
+  replacement_copy = replacement_img.copy()
+  replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
+  mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
+  final_base.paste(replacement_copy, target_position, mask)
+
+  frames.extend([crop_image(final_base)] * 30)  # 1 second at 30fps
+
+  # Convert frames to video using H.264 codec
+  if frames:
+    first_frame = np.array(frames[0])
+    height, width = first_frame.shape[:2]
+    fourcc = cv2.VideoWriter_fourcc(*'avc1')
+    out = cv2.VideoWriter(
+      output_path,
+      fourcc,
+      fps,
+      (width, height),
+      isColor=True
+    )
+
+    if not out.isOpened():
+      print("Error: VideoWriter failed to open")
+      return
+
+    for frame in frames:
+      frame_array = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
+      out.write(frame_array)
+    
+    out.release()
+    print(f"Video saved successfully to {output_path}")

+ 3 - 0
exo/apputil/baseimages/image1.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:361fdadd67c277d45cd18b0bfc8c5ceea5fd89f2d65aef157fd915ce9cbb8599
+size 814460

+ 3 - 0
exo/apputil/baseimages/image2.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0e3891bc6b4f4dfa7444af53fcaa4b3ba06b0549546202be3243f08a0e6bd7e
+size 814235

+ 3 - 0
exo/apputil/baseimages/image3.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2dc5b3378aef397d60fd1252da8a1c578ad97e202a859590ffa416b49551d19
+size 146633

+ 3 - 0
exo/apputil/baseimages/image4.png

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbc6883e2a3c5233ec7b844c98646922bdc4f5e42e1f424857eaff56f785dbcd
+size 668550

+ 1 - 1
exo/inference/dummy_inference_engine.py

@@ -18,7 +18,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     return np.array(self.tokenizer.encode(prompt))
 
-  async def sample(self, x: np.ndarray) -> np.ndarray:
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
     return x
 

+ 3 - 2
exo/inference/mlx/sharded_inference_engine.py

@@ -1,15 +1,16 @@
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
+from mlx_lm.sample_utils import top_p_sampling
 from ..inference_engine import InferenceEngine
 from .stateful_model import StatefulModel
-from .sharded_utils import load_shard, get_image_from_str
+from .sharded_utils import load_shard
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
-from functools import partial
+
 def sample_logits(
   logits: mx.array,
   temp: float = 0.0,

+ 6 - 1
exo/inference/tinygrad/tinygrad_helpers.py

@@ -7,6 +7,7 @@ from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 from exo.download.hf.hf_helpers import get_allow_patterns
 from fnmatch import fnmatch
+import re
 
 
 # **** helper functions ****
@@ -42,6 +43,10 @@ def load(fn: str, shard: Shard):
     if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}
   elif fn.endswith(".safetensors"):
   elif fn.endswith(".safetensors"):
-    return safe_load(fn)
+    weight_map = safe_load(fn)
+    for k in list(weight_map):
+      if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
+          del weight_map[k]
+    return weight_map
   else:
     return torch_load(fn)

+ 11 - 3
exo/main.py

@@ -55,8 +55,10 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
+parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
@@ -88,6 +90,9 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
+# Convert node-id-filter to list if provided
+allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
+
 if args.discovery_module == "udp":
 if args.discovery_module == "udp":
   discovery = UDPDiscovery(
   discovery = UDPDiscovery(
     args.node_id,
     args.node_id,
@@ -95,7 +100,8 @@ if args.discovery_module == "udp":
     args.listen_port,
     args.broadcast_port,
     lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
-    discovery_timeout=args.discovery_timeout
+    discovery_timeout=args.discovery_timeout,
+    allowed_node_ids=allowed_node_ids
   )
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
@@ -104,7 +110,8 @@ elif args.discovery_module == "tailscale":
     lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
     discovery_timeout=args.discovery_timeout,
     tailscale_api_key=args.tailscale_api_key,
-    tailnet=args.tailnet_name
+    tailnet=args.tailnet_name,
+    allowed_node_ids=allowed_node_ids
   )
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
@@ -119,7 +126,8 @@ node = StandardNode(
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader
+  shard_downloader=shard_downloader,
+  default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

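A hedged usage sketch for the two new CLI flags added above; the node IDs and the temperature value are illustrative, and the flag semantics follow the argparse help strings in the diff.

```sh
# Only peer with the listed node IDs (UDP or Tailscale discovery),
# and sample tokens at temperature 0.7 instead of the greedy default of 0.0
exo --node-id-filter node-a,node-b --default-temp 0.7
```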
+ 6 - 0
exo/networking/tailscale/tailscale_discovery.py

@@ -21,6 +21,7 @@ class TailscaleDiscovery(Discovery):
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
     tailscale_api_key: str = None,
     tailnet: str = None,
+    allowed_node_ids: List[str] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -34,6 +35,7 @@ class TailscaleDiscovery(Discovery):
     self.cleanup_task = None
     self.tailscale_api_key = tailscale_api_key
     self.tailnet = tailnet
+    self.allowed_node_ids = allowed_node_ids
     self._device_id = None
     self.update_task = None
 
@@ -84,6 +86,10 @@ class TailscaleDiscovery(Discovery):
             if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
             continue
 
+          if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+            if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+            continue
+
           if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
             new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
             if not await new_peer_handle.health_check():

+ 8 - 0
exo/networking/udp/udp_discovery.py

@@ -45,6 +45,7 @@ class UDPDiscovery(Discovery):
     broadcast_interval: int = 1,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    allowed_node_ids: List[str] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -54,6 +55,7 @@ class UDPDiscovery(Discovery):
     self.broadcast_interval = broadcast_interval
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
+    self.allowed_node_ids = allowed_node_ids
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
@@ -133,6 +135,12 @@ class UDPDiscovery(Discovery):
 
     if message["type"] == "discovery" and message["node_id"] != self.node_id:
       peer_id = message["node_id"]
+      
+      # Skip if peer_id is not in allowed list
+      if self.allowed_node_ids and peer_id not in self.allowed_node_ids:
+        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as it's not in the allowed node IDs list")
+        return
+
       peer_host = addr[0]
       peer_port = message["grpc_port"]
       peer_prio = message["priority"]

+ 3 - 1
exo/orchestration/standard_node.py

@@ -27,6 +27,7 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
+    default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
     shard_downloader: Optional[HFShardDownloader] = None,
   ):
@@ -43,6 +44,7 @@ class StandardNode(Node):
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
+    self.default_sample_temperature = default_sample_temperature
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
@@ -114,7 +116,7 @@ class StandardNode(Node):
         self.buffered_token_output[request_id] = ([], False)
       is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if shard.is_last_layer() and not is_finished:
-        token = await self.inference_engine.sample(result)
+        token = await self.inference_engine.sample(result, temp = self.default_sample_temperature)
         self.buffered_token_output[request_id][0].append(token.item())
         intermediate_result = self.buffered_token_output[request_id][0]
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")

+ 138 - 56
extra/dashboard/dashboard.py

@@ -10,7 +10,7 @@ from pathlib import Path
 from plotly.subplots import make_subplots
 import plotly.graph_objects as go
 import time
-import simpleaudio as sa
+import pygame.mixer
 from datetime import datetime
 
 class AsyncCircleCIClient:
@@ -39,15 +39,7 @@ class AsyncCircleCIClient:
     ):
         """
         Get recent pipelines for a project with pagination support
-
-        Args:
-            session: aiohttp client session
-            org_slug: Organization slug
-            page_token: Token for pagination
-            limit: Maximum number of pipelines to return
-            branch: Specific branch to fetch pipelines from
         """
         """
-
         params = {
             "branch": branch,
             "page-token": page_token
@@ -62,13 +54,17 @@ class AsyncCircleCIClient:
 
         next_page_token = data.get("next_page_token")
 
+        # If we have a limit, check if we need more pages
+        if limit and len(pipelines) >= limit:
+            return pipelines
+
         # If there are more pages and we haven't hit the limit, recursively get them
-        if next_page_token and (limit is None or len(pipelines) < limit):
+        if next_page_token:
             next_pipelines = await self.get_recent_pipelines(
                 session,
                 org_slug,
                 page_token=next_page_token,
-                limit=limit,
+                limit=limit - len(pipelines) if limit else None,  # Adjust limit for next page
                 branch=branch
             )
             pipelines.extend(next_pipelines)
@@ -108,20 +104,30 @@ class PackageSizeTracker:
         self.client = AsyncCircleCIClient(token, project_slug)
         self.logger = logging.getLogger("PackageSizeTracker")
         self.last_data_hash = None
-        self.notification_sound_path = Path(__file__).parent / "notification.wav"
         self.debug = debug
 
-        # Sound file paths - replace these with your actual sound files
+        # Initialize pygame mixer
+        pygame.mixer.init()
+
+        # Sound file paths - can use MP3 files with pygame
         sounds_dir = Path(__file__).parent / "sounds"
         sounds_dir = Path(__file__).parent / "sounds"
         self.sounds = {
         self.sounds = {
-            'lines_up': sounds_dir / "lines_increased.wav",
-            'lines_down': sounds_dir / "lines_decreased.wav",
-            'tokens_up': sounds_dir / "tokens_increased.wav",
-            'tokens_down': sounds_dir / "tokens_decreased.wav",
-            'size_up': sounds_dir / "size_increased.wav",
-            'size_down': sounds_dir / "size_decreased.wav"
+            'lines_up': sounds_dir / "gta5_wasted.mp3",
+            'lines_down': sounds_dir / "pokemon_evolve.mp3",
+            'tokens_up': sounds_dir / "pokemon_evolve.mp3",
+            'tokens_down': sounds_dir / "gta5_wasted.mp3",
+            'size_up': sounds_dir / "gta5_wasted.mp3",
+            'size_down': sounds_dir / "pokemon_evolve.mp3"
         }
 
+    def test_sound_effects(self):
+        """Test all sound effects with a small delay between each"""
+        self.logger.info("Testing sound effects...")
+        for sound_key in self.sounds:
+            self.logger.info(f"Playing {sound_key}")
+            self._play_sound(sound_key)
+            time.sleep(1)  # Wait 1 second between sounds
+
     def setup_logging(self, debug: bool):
         level = logging.DEBUG if debug else logging.INFO
         logging.basicConfig(
@@ -274,16 +280,57 @@ class PackageSizeTracker:
             self.logger.error(f"Error processing pipeline {pipeline['id']}: {str(e)}")
             self.logger.error(f"Error processing pipeline {pipeline['id']}: {str(e)}")
             return None
             return None
 
 
+    async def process_pipeline_batch(
+        self,
+        session: aiohttp.ClientSession,
+        pipelines: List[Dict],
+        batch_size: int = 5
+    ) -> List[Dict]:
+        """
+        Process a batch of pipelines with rate limiting.
+
+        Args:
+            session: aiohttp client session
+            pipelines: List of pipelines to process
+            batch_size: Number of pipelines to process in parallel
+
+        Returns:
+            List of processed pipeline data points
+        """
+        data_points = []
+
+        for i in range(0, len(pipelines), batch_size):
+            batch = pipelines[i:i + batch_size]
+
+            # Process batch in parallel
+            tasks = [self.process_pipeline(session, pipeline) for pipeline in batch]
+            batch_results = await asyncio.gather(*tasks)
+
+            # Filter out None results
+            batch_data = [r for r in batch_results if r is not None]
+            data_points.extend(batch_data)
+
+            # Add delay between batches if there are more to process
+            if i + batch_size < len(pipelines):
+                await asyncio.sleep(1)  # 1 second delay between batches
+
+        return data_points
+
     async def collect_data(self) -> List[Dict]:
         self.logger.info("Starting data collection...")
         async with aiohttp.ClientSession(headers=self.client.headers) as session:
-            # Get pipelines from both main and circleci branches
+            # Get pipelines from main branch
             main_pipelines = await self.client.get_recent_pipelines(
                 session,
                 org_slug=self.client.project_slug,
                 limit=20,
                 branch="main"
             )
+
+            # Add delay between branch requests
+            await asyncio.sleep(2)
+
+            # Get pipelines from circleci branch
             circleci_pipelines = await self.client.get_recent_pipelines(
                 session,
                 org_slug=self.client.project_slug,
@@ -291,18 +338,27 @@ class PackageSizeTracker:
                 branch="circleci"
                 branch="circleci"
             )
             )
 
 
+            # Combine pipelines and sort by created_at date
             pipelines = main_pipelines + circleci_pipelines
-            # Sort pipelines by created_at date
-            pipelines.sort(key=lambda x: x.get("created_at", x.get("updated_at")), reverse=True)
+            pipelines.sort(
+                key=lambda x: datetime.fromisoformat(
+                    x.get("created_at", x.get("updated_at")).replace('Z', '+00:00')
+                ),
+                reverse=True  # Most recent first
+            )
 
             self.logger.info(f"Found {len(pipelines)} recent pipelines")
 
-            # Process all pipelines in parallel
-            tasks = [self.process_pipeline(session, pipeline) for pipeline in pipelines]
-            results = await asyncio.gather(*tasks)
+            # Process pipelines in batches
+            data_points = await self.process_pipeline_batch(session, pipelines)
 
-            # Filter out None results
-            data_points = [r for r in results if r is not None]
+            # Sort by timestamp
+            data_points.sort(
+                key=lambda x: datetime.fromisoformat(
+                    x.get("timestamp").replace('Z', '+00:00')
+                ),
+                reverse=True  # Most recent first
+            )
 
         return data_points
 
@@ -931,60 +987,86 @@ class PackageSizeTracker:
         ])))
 
     def _play_sound(self, sound_key: str):
-        """Play a specific notification sound"""
+        """Play a specific notification sound using pygame"""
         try:
             sound_path = self.sounds.get(sound_key)
             if sound_path and sound_path.exists():
-                wave_obj = sa.WaveObject.from_wave_file(str(sound_path))
-                wave_obj.play()
+                sound = pygame.mixer.Sound(str(sound_path))
+                sound.play()
+                # Wait for the sound to finish playing
+                pygame.time.wait(int(sound.get_length() * 1000))
             else:
                 self.logger.warning(f"Sound file not found: {sound_key} at {sound_path}")
         except Exception as e:
             self.logger.error(f"Failed to play sound {sound_key}: {e}")
 
     def _check_metrics_changes(self, current_data: List[Dict], previous_data: List[Dict]):
-        """Check for specific metric changes and play appropriate sounds"""
-        if not previous_data:
-            return
+        # Sort data by timestamp in descending order (most recent first)
+        def sort_by_timestamp(data):
+            return sorted(
+                data,
+                key=lambda x: x.get('timestamp', ''),
+                reverse=True  # Most recent first
+            )
+
+        current_data = sort_by_timestamp(current_data)
+        previous_data = sort_by_timestamp(previous_data)
 
-        # Get latest data points
-        current = current_data[-1]
-        previous = previous_data[-1]
+        # Helper to find latest entry with a specific metric
+        def find_latest_with_metric(data: List[Dict], metric: str) -> Optional[Dict]:
+            return next((d for d in data if metric in d), None)
 
         # Check line count changes
-        if 'total_lines' in current and 'total_lines' in previous:
-            diff = current['total_lines'] - previous['total_lines']
+        current_lines = find_latest_with_metric(current_data, 'total_lines')
+        previous_lines = find_latest_with_metric(previous_data, 'total_lines')
+
+        if current_lines and previous_lines:
+            diff = current_lines['total_lines'] - previous_lines['total_lines']
+            self.logger.debug(f"Lines of code diff: {diff}")
             if diff > 0:
                 self.logger.info(f"Lines of code increased by {diff:,}")
                 self._play_sound('lines_up')
             elif diff < 0:
                 self.logger.info(f"Lines of code decreased by {abs(diff):,}")
                 self._play_sound('lines_down')
+        else:
+            self.logger.debug("No lines of code data found")
 
         # Check tokens per second changes
-        if 'tokens_per_second' in current and 'tokens_per_second' in previous:
-            diff = current['tokens_per_second'] - previous['tokens_per_second']
+        current_tokens = find_latest_with_metric(current_data, 'tokens_per_second')
+        previous_tokens = find_latest_with_metric(previous_data, 'tokens_per_second')
+
+        if current_tokens and previous_tokens:
+            diff = current_tokens['tokens_per_second'] - previous_tokens['tokens_per_second']
+            self.logger.debug(f"Tokens per second diff: {diff}")
             if diff > 0:
                 self.logger.info(f"Tokens per second increased by {diff:.2f}")
                 self._play_sound('tokens_up')
             elif diff < 0:
                 self.logger.info(f"Tokens per second decreased by {abs(diff):.2f}")
                 self._play_sound('tokens_down')
+        else:
+            self.logger.debug("No tokens per second data found")
 
         # Check package size changes
-        if 'total_size_mb' in current and 'total_size_mb' in previous:
-            diff = current['total_size_mb'] - previous['total_size_mb']
+        current_size = find_latest_with_metric(current_data, 'total_size_mb')
+        previous_size = find_latest_with_metric(previous_data, 'total_size_mb')
+
+        if current_size and previous_size:
+            diff = current_size['total_size_mb'] - previous_size['total_size_mb']
+            self.logger.debug(f"Package size diff: {diff:.2f}MB")
             if diff > 0:
                 self.logger.info(f"Package size increased by {diff:.2f}MB")
                 self._play_sound('size_up')
             elif diff < 0:
                 self.logger.info(f"Package size decreased by {abs(diff):.2f}MB")
                 self._play_sound('size_down')
+        else:
+            self.logger.debug("No package size data found")
 
-    async def run_dashboard(self, update_interval: int = 30):
+    async def run_dashboard(self, update_interval: int = 10):
         """Run the dashboard with periodic updates"""
         """Run the dashboard with periodic updates"""
         try:
         try:
-            # Force convert update_interval to float and log its type
             update_interval = float(update_interval)
             self.logger.debug(f"Update interval type: {type(update_interval)}, value: {update_interval}")
         except ValueError as e:
@@ -997,7 +1079,6 @@ class PackageSizeTracker:
         while True:
             try:
                 start_time = time.time()
-                self.logger.debug(f"Start time type: {type(start_time)}, value: {start_time}")
 
                 # Collect new data
                 current_data = await self.collect_data()
@@ -1013,25 +1094,26 @@ class PackageSizeTracker:
                         f"Dashboard updated at {datetime.now().strftime('%H:%M:%S')}"
                         f"Dashboard updated at {datetime.now().strftime('%H:%M:%S')}"
                     )
                     )
 
 
-                    # Check for metric changes and play appropriate sounds
-                    self._check_metrics_changes(current_data, previous_data)
+                    print("Curr:", len(current_data))
+                    print("Prev:", len(previous_data) if previous_data else "None")
+                    if previous_data:
+                        # Check for metric changes and play appropriate sounds
+                        self.logger.debug(f"Checking metrics changes between {len(current_data)} current and {len(previous_data)} previous data points")
+                        self._check_metrics_changes(current_data, previous_data)
 
                 # Update previous data
-                previous_data = current_data
+                previous_data = current_data.copy()  # Make a copy to prevent reference issues
 
-                # Calculate sleep time with explicit type conversion and logging
+                # Calculate sleep time
                 elapsed = float(time.time() - start_time)
-                self.logger.debug(f"Elapsed time type: {type(elapsed)}, value: {elapsed}")
-                sleep_time = max(0.0, float(update_interval) - elapsed)
-                self.logger.debug(f"Sleep time type: {type(sleep_time)}, value: {sleep_time}")
-
+                sleep_time = max(0.0, update_interval - elapsed)
                 await asyncio.sleep(sleep_time)
 
             except Exception as e:
                 self.logger.error(f"Error in dashboard update loop: {e}", exc_info=True)
                 if self.debug:
                     raise
-                await asyncio.sleep(float(update_interval))
+                await asyncio.sleep(update_interval)
 
 async def main():
     token = os.getenv("CIRCLECI_TOKEN")
@@ -1040,11 +1122,11 @@ async def main():
 
     try:
         # Get update interval from environment or use default
-        update_interval = float(os.getenv("UPDATE_INTERVAL", "30"))
+        update_interval = float(os.getenv("UPDATE_INTERVAL", "10"))
         print(f"Update interval type: {type(update_interval)}, value: {update_interval}")  # Debug print
         print(f"Update interval type: {type(update_interval)}, value: {update_interval}")  # Debug print
     except ValueError as e:
     except ValueError as e:
         print(f"Error converting UPDATE_INTERVAL to float: {os.getenv('UPDATE_INTERVAL')}")
         print(f"Error converting UPDATE_INTERVAL to float: {os.getenv('UPDATE_INTERVAL')}")
-        update_interval = 30.0
+        update_interval = 10.0
 
     if not token or not project_slug:
         print("Error: Please set CIRCLECI_TOKEN and CIRCLECI_PROJECT_SLUG environment variables")

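A hedged sketch of running the updated dashboard; the environment variable names and the script path come from the code above, while the token and project slug values are placeholders.

```sh
# CircleCI credentials and polling interval are read via os.getenv in dashboard.py
CIRCLECI_TOKEN=<your-token> \
CIRCLECI_PROJECT_SLUG=gh/<org>/<repo> \
UPDATE_INTERVAL=10 \
python extra/dashboard/dashboard.py
```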
+ 1 - 1
extra/dashboard/requirements.txt

@@ -2,4 +2,4 @@ plotly
 pandas
 requests
 aiohttp
-simpleaudio
+pygame

+ 3 - 0
extra/dashboard/sounds/gta5_wasted.mp3

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb3fb66dd02827fbff86ef1ce3bc6438371c823aed7d4c3803ed522f008e4947
+size 206399

+ 3 - 0
extra/dashboard/sounds/pokemon_evolve.mp3

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d99cc9bdab4a4639d50f439b424547000e7c79f195b5b121734ad4ead435911c
+size 633345

+ 7 - 1
install.sh

@@ -1,5 +1,11 @@
 #!/bin/bash
 
-python3 -m venv .venv
+if command -v python3.12 &>/dev/null; then
+    echo "Python 3.12 is installed, proceeding with python3.12..."
+    python3.12 -m venv .venv
+else
+    echo "The recommended version of Python to run exo with is Python 3.12, but $(python3 --version) is installed. Proceeding with $(python3 --version)"
+    python3 -m venv .venv
+fi
 source .venv/bin/activate
 pip install -e .

+ 1 - 0
setup.py

@@ -15,6 +15,7 @@ install_requires = [
   "numpy==2.0.0",
   "numpy==2.0.0",
   "nuitka==2.5.1",
   "nuitka==2.5.1",
   "nvidia-ml-py==12.560.30",
   "nvidia-ml-py==12.560.30",
+  "opencv-python==4.10.0.84",
   "pillow==10.4.0",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
   "prometheus-client==0.20.0",
   "protobuf==5.28.1",
   "protobuf==5.28.1",