Selaa lähdekoodia

Merge pull request #617 from exo-explore/runners2

Lots of fixes and QoL improvements.
Alex Cheema 6 kuukautta sitten
vanhempi
commit
2644fd02c8
39 muutettua tiedostoa jossa 2117 lisäystä ja 542 poistoa
  1. 28 0
      .circleci/config.yml
  2. 401 0
      .github/bench.py
  3. 330 0
      .github/bootstrap.sh
  4. 95 0
      .github/optimize_performance.sh
  5. 206 0
      .github/workflows/bench_job.yml
  6. 71 0
      .github/workflows/benchmarks.yml
  7. 32 7
      configure_mlx.sh
  8. 75 71
      exo/api/chatgpt_api.py
  9. 1 1
      exo/download/hf/hf_helpers.py
  10. 2 1
      exo/download/hf/hf_shard_download.py
  11. 35 9
      exo/helpers.py
  12. 7 0
      exo/inference/mlx/perf_improvements.md
  13. 80 68
      exo/inference/mlx/sharded_inference_engine.py
  14. 81 0
      exo/inference/mlx/test_non_blocking.py
  15. 79 29
      exo/main.py
  16. 24 18
      exo/networking/grpc/grpc_peer_handle.py
  17. 11 3
      exo/networking/grpc/grpc_server.py
  18. 0 10
      exo/networking/grpc/node_service.proto
  19. 2 2
      exo/networking/grpc/node_service_pb2.py
  20. 44 87
      exo/networking/grpc/node_service_pb2_grpc.py
  21. 1 2
      exo/networking/manual/manual_discovery.py
  22. 0 4
      exo/networking/peer_handle.py
  23. 1 1
      exo/networking/tailscale/tailscale_discovery.py
  24. 42 14
      exo/networking/udp/udp_discovery.py
  25. 19 51
      exo/orchestration/node.py
  26. 9 0
      exo/orchestration/test_node.py
  27. 166 0
      exo/orchestration/tracing.py
  28. 0 0
      exo/stats/__init__.py
  29. 0 27
      exo/stats/docker-compose-stats.yml
  30. 0 29
      exo/stats/metrics.py
  31. 0 7
      exo/stats/prometheus.yml
  32. 88 0
      exo/tinychat/index.css
  33. 37 22
      exo/tinychat/index.html
  34. 78 10
      exo/tinychat/index.js
  35. 16 23
      exo/topology/device_capabilities.py
  36. 3 1
      exo/topology/partitioning_strategy.py
  37. 41 38
      exo/topology/test_device_capabilities.py
  38. 3 3
      extra/line_counter.py
  39. 9 4
      setup.py

+ 28 - 0
.circleci/config.yml

@@ -254,6 +254,33 @@ jobs:
           prompt: "Keep responses concise. Who was the king of pop?"
           expected_output: "Michael Jackson"
 
+  chatgpt_api_integration_test_tinygrad_linux:
+    machine:
+      image: ubuntu-2204:current
+    resource_class: xlarge
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            sudo apt-get update
+            sudo add-apt-repository -y ppa:deadsnakes/ppa
+            sudo apt-get update
+            sudo apt-get install -y python3.12 python3.12-venv clang
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: tinygrad
+          model_id: llama-3.2-1b
+          prompt: "Keep responses concise. Who was the king of pop?"
+          expected_output: "Michael Jackson"
+
   measure_pip_sizes:
     macos:
       xcode: "16.0.0"
@@ -342,5 +369,6 @@ workflows:
       - discovery_integration_test
       - chatgpt_api_integration_test_mlx
       - chatgpt_api_integration_test_tinygrad
+      - chatgpt_api_integration_test_tinygrad_linux
       - chatgpt_api_integration_test_dummy
       - measure_pip_sizes

+ 401 - 0
.github/bench.py

@@ -0,0 +1,401 @@
+import aiohttp
+import asyncio
+import time
+import json
+import os
+import boto3
+from typing import Dict, Any
+from datetime import datetime
+import subprocess
+import psutil
+import platform
+from pathlib import Path
+
+
+def check_system_state():
+    print("\n=== System State Check ===", flush=True)
+    
+    # Add macOS-specific checks
+    try:
+        # Check powermetrics with sudo
+        try:
+            power_metrics = subprocess.run(
+                ['sudo', 'powermetrics', '-n', '1', '-i', '1000', '--samplers', 'cpu_power'],
+                capture_output=True, text=True
+            )
+            print("\nPower Metrics:", power_metrics.stdout, flush=True)
+        except Exception as e:
+            print(f"Error getting power metrics: {e}", flush=True)
+        
+        # Check thermal state
+        thermal_state = subprocess.run(['pmset', '-g', 'therm'], capture_output=True, text=True)
+        print("\nThermal State:", thermal_state.stdout, flush=True)
+        
+        # Check if running under Rosetta
+        arch = subprocess.run(['arch'], capture_output=True, text=True)
+        print("\nArchitecture:", arch.stdout, flush=True)
+        
+        # Check MLX compilation mode - only if mlx is available
+        try:
+            import mlx.core as mx
+            if hasattr(mx, 'build_info'):
+                print("\nMLX Build Info:", mx.build_info(), flush=True)
+            else:
+                print("\nMLX Build Info: Not available in this version", flush=True)
+        except ImportError:
+            print("\nMLX: Not installed", flush=True)
+        except Exception as e:
+            print(f"\nError checking MLX: {e}", flush=True)
+        
+    except Exception as e:
+        print(f"Error in macOS checks: {e}", flush=True)
+
+    # CPU Info
+    print("\nCPU Information:", flush=True)
+    try:
+        if platform.system() == 'Darwin' and platform.processor() == 'arm':
+            # Use sysctl for Apple Silicon Macs
+            cpu_info = subprocess.run(['sysctl', 'machdep.cpu'], capture_output=True, text=True)
+            if cpu_info.returncode == 0:
+                print(f"CPU Info (Apple Silicon):", cpu_info.stdout, flush=True)
+            
+            # Parse powermetrics output for clearer CPU frequency display
+            try:
+                power_metrics = subprocess.run(
+                    ['sudo', 'powermetrics', '-n', '1', '-i', '100', '--samplers', 'cpu_power'],
+                    capture_output=True, text=True
+                )
+                if power_metrics.returncode == 0:
+                    output = power_metrics.stdout
+                    print("\nDetailed CPU Frequency Information:")
+                    
+                    # Extract cluster frequencies and max frequencies
+                    current_cluster = None
+                    max_freqs = {'E': 0, 'P0': 0, 'P1': 0}
+                    
+                    for line in output.split('\n'):
+                        # Track which cluster we're processing
+                        if "E-Cluster" in line:
+                            current_cluster = 'E'
+                        elif "P0-Cluster" in line:
+                            current_cluster = 'P0'
+                        elif "P1-Cluster" in line:
+                            current_cluster = 'P1'
+                            
+                        # Get current frequencies
+                        if "HW active frequency:" in line:
+                            freq = line.split(':')[1].strip()
+                            if freq != "0 MHz":
+                                print(f"Current {current_cluster}-Cluster Frequency: {freq}")
+                        
+                        # Get max frequencies from residency lines
+                        if current_cluster and "active residency:" in line and "MHz:" in line:
+                            try:
+                                # Extract all frequency values
+                                freqs = []
+                                parts = line.split('MHz:')[:-1]  # Skip last part as it's not a frequency
+                                for part in parts:
+                                    freq_str = part.split()[-1]
+                                    try:
+                                        freq = float(freq_str)
+                                        freqs.append(freq)
+                                    except ValueError:
+                                        continue
+                                if freqs:
+                                    max_freqs[current_cluster] = max(max_freqs[current_cluster], max(freqs))
+                            except Exception:
+                                continue
+                    
+                    # Print max frequencies
+                    print("\nMaximum Available Frequencies:")
+                    for cluster, max_freq in max_freqs.items():
+                        if max_freq > 0:
+                            print(f"{cluster}-Cluster Max: {max_freq:.0f} MHz")
+                            
+            except Exception as e:
+                print(f"Error parsing powermetrics: {e}", flush=True)
+        else:
+            # Use psutil for other systems
+            cpu_freq = psutil.cpu_freq()
+            print(f"CPU Frequency - Current: {cpu_freq.current:.2f}MHz, Min: {cpu_freq.min:.2f}MHz, Max: {cpu_freq.max:.2f}MHz", flush=True)
+        
+        print(f"\nCPU Usage per Core: {psutil.cpu_percent(percpu=True)}%", flush=True)
+        
+        # Check if running in low power mode
+        power_mode = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
+        print("\nPower Settings:", power_mode.stdout, flush=True)
+    except Exception as e:
+        print(f"Error getting CPU info: {e}", flush=True)
+
+    # Memory Info
+    print("\nMemory Information:", flush=True)
+    try:
+        mem = psutil.virtual_memory()
+        print(f"Total: {mem.total/1024/1024/1024:.2f}GB", flush=True)
+        print(f"Available: {mem.available/1024/1024/1024:.2f}GB", flush=True)
+        print(f"Used: {mem.used/1024/1024/1024:.2f}GB ({mem.percent}%)", flush=True)
+        
+        # Check swap
+        swap = psutil.swap_memory()
+        print(f"Swap Used: {swap.used/1024/1024/1024:.2f}GB of {swap.total/1024/1024/1024:.2f}GB", flush=True)
+    except Exception as e:
+        print(f"Error getting memory info: {e}", flush=True)
+
+    # GPU Info
+    print("\nGPU Information:", flush=True)
+    try:
+        # Check MLX GPU settings
+        print("MLX Environment Variables:", flush=True)
+        mlx_vars = {k: v for k, v in os.environ.items() if k.startswith('MLX')}
+        print(json.dumps(mlx_vars, indent=2), flush=True)
+        
+        # Check Metal GPU memory allocation
+        gpu_mem = subprocess.run(['sysctl', 'iogpu'], capture_output=True, text=True)
+        print("GPU Memory Settings:", gpu_mem.stdout, flush=True)
+    except Exception as e:
+        print(f"Error getting GPU info: {e}", flush=True)
+
+    # Process Priority
+    print("\nProcess Priority Information:", flush=True)
+    try:
+        current_process = psutil.Process()
+        print(f"Process Nice Value: {current_process.nice()}", flush=True)
+        # Only try to get ionice if the platform supports it
+        if hasattr(current_process, 'ionice'):
+            print(f"Process IO Nice Value: {current_process.ionice()}", flush=True)
+    except Exception as e:
+        print(f"Error getting process priority info: {e}", flush=True)
+
+    # System Load
+    print("\nSystem Load:", flush=True)
+    try:
+        load_avg = psutil.getloadavg()
+        print(f"Load Average: {load_avg}", flush=True)
+        
+        # Get top processes by CPU and Memory
+        print("\nTop Processes by CPU Usage:", flush=True)
+        processes = []
+        for proc in psutil.process_iter(['pid', 'name', 'cpu_percent', 'memory_percent']):
+            try:
+                pinfo = proc.info
+                if pinfo['cpu_percent'] is not None and pinfo['memory_percent'] is not None:
+                    processes.append(pinfo)
+            except (psutil.NoSuchProcess, psutil.AccessDenied):
+                continue
+        
+        # Sort and display top 5 CPU-consuming processes
+        sorted_by_cpu = sorted(processes, key=lambda x: x['cpu_percent'] or 0, reverse=True)[:5]
+        for proc in sorted_by_cpu:
+            print(f"PID: {proc['pid']}, Name: {proc['name']}, CPU: {proc['cpu_percent']}%, Memory: {proc['memory_percent']:.1f}%")
+    except Exception as e:
+        print(f"Error getting system load info: {e}", flush=True)
+
+    print("\n=== End System State Check ===\n", flush=True)
+
+
+def check_gpu_access():
+    try:
+        # Check if MLX can see the GPU
+        import mlx.core as mx
+        print("MLX device info:", mx.default_device())
+        
+        # Check Metal device availability
+        result = subprocess.run(['system_profiler', 'SPDisplaysDataType'], capture_output=True, text=True)
+        print("GPU Info:", result.stdout)
+    except Exception as e:
+        print(f"Failed to check GPU access: {e}")
+
+
+async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
+    """
+    Measures the performance of an API endpoint by sending a prompt and recording metrics.
+
+    Args:
+        api_endpoint (str): The API endpoint URL.
+        prompt (str): The prompt to send to the API.
+
+    Returns:
+        Dict[str, Any]: A dictionary containing performance metrics or error information.
+    """
+
+    results = {
+        'model': model,
+        'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
+        'branch': os.environ.get('GITHUB_REF_NAME', 'unknown'),
+        'commit': os.environ.get('GITHUB_SHA', 'unknown'),
+        'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
+    }
+
+    # Get token count
+    session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600, connect=10, sock_read=600, sock_connect=10))
+    try:
+        response = await session.post(
+            "http://localhost:52415/v1/chat/token/encode",
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": prompt}]
+            }
+        )
+        response.raise_for_status()
+        token_data = await response.json()
+        results['prompt_len'] = token_data['num_tokens']
+    except Exception as e:
+        await session.close()
+        raise RuntimeError(f"Failed to get token count: {str(e)}")
+
+    # Measure completion performance
+    try:
+        start_time = time.time()
+        response = await session.post(
+            api_endpoint,
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": 0,
+                "stream": True
+            }
+        )
+        response.raise_for_status()
+
+        first_token_time = None
+        total_tokens = 0
+
+        async for line in response.content.iter_chunks():
+            line = line[0].decode('utf-8').strip()
+            if not line.startswith('data: '):
+                continue
+
+            data = json.loads(line[6:])  # Skip 'data: ' prefix
+            if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
+                print(f"Received content: {content}", flush=True)
+                if first_token_time is None:
+                    first_token_time = time.time()
+                    ttft = first_token_time - start_time
+                    results.update({
+                        'ttft': ttft,
+                        'prompt_tps': results['prompt_len'] / ttft
+                    })
+                total_tokens += 1
+
+        total_time = time.time() - start_time
+        results.update({
+            'generation_tps': total_tokens / total_time,
+            'response_len': total_tokens,
+            'total_time': total_time
+        })
+
+    except Exception as e:
+        raise RuntimeError(f"Performance measurement failed: {str(e)}")
+    finally:
+        await session.close()
+
+    return results
+
+
+async def main() -> None:
+    api_endpoint = "http://localhost:52415/v1/chat/completions"
+
+    # Define prompts
+    prompt_warmup = "what is the capital of France?"
+    prompt_essay = "write an essay about cats"
+
+    model = os.environ.get('model', 'llama-3.2-1b')
+    # Warmup request
+    print("\nPerforming warmup request...", flush=True)
+    try:
+        warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
+        print("Warmup completed successfully", flush=True)
+    except Exception as e:
+        print(f"Warmup request failed: {e}", flush=True)
+
+    # Measure performance for the essay prompt
+    print("\nMeasuring performance for the essay prompt...", flush=True)
+    results = await measure_performance(api_endpoint, prompt_essay, model)
+
+    try:
+        s3_client = boto3.client(
+            's3',
+            aws_access_key_id=os.environ.get('aws_access_key_id'),
+            aws_secret_access_key=os.environ.get('aws_secret_key')
+        )
+        job_name = os.environ.get('GITHUB_JOB')
+
+        # Create S3 key with timestamp and commit info
+        now = datetime.utcnow()
+        timestamp = now.strftime('%H-%M-%S')
+        commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
+        s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
+
+        # Upload to S3
+        s3_client.put_object(
+            Bucket='exo-benchmarks',
+            Key=s3_key,
+            Body=json.dumps(results),
+            ContentType='application/json'
+        )
+        print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
+    except Exception as e:
+        print(f"Failed to upload metrics to S3: {e}", flush=True)
+
+    # Optionally print the metrics for visibility
+    print("Performance metrics:", flush=True)
+    print(json.dumps(results, indent=4), flush=True)
+
+
+def optimize_system_performance():
+    """Set optimal system performance settings before running benchmark."""
+    try:
+        # Try to set high performance power mode
+        subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
+        
+        # Ensure MLX uses performance cores and GPU
+        os.environ['MLX_FORCE_P_CORES'] = '1'
+        os.environ['MLX_METAL_PREWARM'] = '1'
+        os.environ['MLX_USE_GPU'] = '1'
+        
+        # Set process priority
+        current_process = psutil.Process()
+        try:
+            # Set highest priority
+            subprocess.run(['sudo', 'renice', '-n', '-20', '-p', str(current_process.pid)], check=False)
+            
+            # Print current process state
+            print("\nProcess State Before Benchmark:", flush=True)
+            proc_info = subprocess.run(
+                ['ps', '-o', 'pid,ppid,user,%cpu,%mem,nice,stat,pri,command', '-p', str(current_process.pid)],
+                capture_output=True, text=True
+            )
+            print(proc_info.stdout, flush=True)
+            
+            # Verify power mode
+            power_info = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
+            if 'powermode            0' in power_info.stdout:
+                print("\nWarning: System still in normal power mode. Trying to set high performance mode again...", flush=True)
+                subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
+            
+        except Exception as e:
+            print(f"Warning: Could not set process priority: {e}", flush=True)
+            
+    except Exception as e:
+        print(f"Warning: Could not optimize system performance: {e}", flush=True)
+    
+    # Print optimization status
+    print("\nOptimization Settings:", flush=True)
+    print("MLX Environment Variables:", flush=True)
+    for var in ['MLX_FORCE_P_CORES', 'MLX_METAL_PREWARM', 'MLX_USE_GPU']:
+        print(f"{var}: {os.environ.get(var, 'Not set')}", flush=True)
+    
+    try:
+        nice_value = psutil.Process().nice()
+        print(f"Process Nice Value: {nice_value}", flush=True)
+        if nice_value != -20:
+            print("Warning: Process not running at highest priority", flush=True)
+    except Exception:
+        pass
+
+
+if __name__ == "__main__":
+    check_system_state()
+    check_gpu_access()
+    optimize_system_performance()
+    asyncio.run(main())

+ 330 - 0
.github/bootstrap.sh

@@ -0,0 +1,330 @@
+#!/bin/bash
+set -e
+
+command_exists() {
+    command -v "$1" >/dev/null 2>&1
+}
+
+log() {
+    echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
+}
+
+if [ "$EUID" -eq 0 ]; then 
+    log "Please do not run as root. Run as regular user with sudo access."
+    exit 1
+fi
+
+# Check for required arguments
+if [ -z "$1" ]; then
+    log "Error: Runner token is required"
+    log "Usage: $0 <runner-token> [tailscale-auth-key]"
+    exit 1
+fi
+
+RUNNER_TOKEN=$1
+TAILSCALE_AUTH_KEY=$2
+REPO="exo-explore/exo"
+
+# Add sudoers configuration
+log "Configuring sudo access..."
+SUDOERS_CONTENT="$(whoami) ALL=(ALL) NOPASSWD: ALL"
+echo "$SUDOERS_CONTENT" | sudo tee /etc/sudoers.d/github-runner > /dev/null
+sudo chmod 440 /etc/sudoers.d/github-runner
+
+log "Configuring privacy permissions..."
+sudo tccutil reset All
+sudo tccutil reset SystemPolicyAllFiles
+sudo tccutil reset SystemPolicyNetworkVolumes
+
+# Configure power management for maximum performance
+log "Configuring power management..."
+sudo pmset -a powermode 2  # Force highest performance mode
+sudo pmset -a gpuswitch 2  # Force discrete/high-performance GPU
+sudo pmset -a lowpowermode 0
+sudo pmset -a lessbright 0
+sudo pmset -a disablesleep 1
+sudo pmset -a sleep 0
+sudo pmset -a hibernatemode 0
+sudo pmset -a autopoweroff 0
+sudo pmset -a standby 0
+sudo pmset -a powernap 0
+
+# For Python specifically
+PYTHON_PATH="/opt/homebrew/bin/python3.12"
+sudo chmod 755 "$PYTHON_PATH"
+
+# Add to firewall
+log "Configuring firewall access..."
+sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$PYTHON_PATH"
+sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$PYTHON_PATH"
+
+# Set Homebrew paths based on architecture
+if [ "$(uname -p)" = "arm" ]; then
+    BREW_PREFIX="/opt/homebrew"
+else
+    BREW_PREFIX="/usr/local"
+fi
+
+# Install Homebrew if not present
+if ! command_exists brew; then
+    log "Installing Homebrew..."
+    /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
+    echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zshrc
+    eval "$(/opt/homebrew/bin/brew shellenv)"
+fi
+
+# Install required packages
+log "Installing required packages..."
+export HOMEBREW_NO_AUTO_UPDATE=1
+brew install python@3.12 coreutils
+
+# Optional Tailscale setup if auth key is provided
+if [ -n "$TAILSCALE_AUTH_KEY" ]; then
+    log "Installing and configuring Tailscale..."
+    brew install --quiet tailscale
+    sudo brew services stop tailscale 2>/dev/null || true
+    sudo rm -f /var/db/tailscale/tailscaled.state 2>/dev/null || true
+    sudo brew services start tailscale
+    sleep 2
+    sudo tailscale up --authkey=$TAILSCALE_AUTH_KEY
+
+    # Enable SSH and Screen Sharing
+    log "Enabling remote access services..."
+    sudo launchctl load -w /System/Library/LaunchDaemons/ssh.plist
+    sudo /System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart \
+        -activate \
+        -configure -access -on \
+        -configure -allowAccessFor -allUsers \
+        -configure -restart -agent -privs -all
+
+    # Create launch daemon for remote access
+    sudo bash -c 'cat > /Library/LaunchDaemons/com.remote.access.setup.plist' << 'EOL'
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+    <key>Label</key>
+    <string>com.remote.access.setup</string>
+    <key>ProgramArguments</key>
+    <array>
+        <string>/bin/bash</string>
+        <string>-c</string>
+        <string>
+            launchctl load -w /System/Library/LaunchDaemons/ssh.plist;
+            /System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart -activate -configure -access -on
+        </string>
+    </array>
+    <key>RunAtLoad</key>
+    <true/>
+</dict>
+</plist>
+EOL
+
+    sudo chmod 644 /Library/LaunchDaemons/com.remote.access.setup.plist
+    sudo launchctl load -w /Library/LaunchDaemons/com.remote.access.setup.plist
+fi
+
+# Configure GitHub Actions Runner
+log "Gathering system metadata..."
+MACHINE_NAME=$(scutil --get ComputerName)
+MACHINE_NAME="runner-$(echo -n "$MACHINE_NAME" | tr '[:upper:]' '[:lower:]' | tr -cd '[:alnum:]-')"
+
+# Enhanced Apple Silicon detection
+MACHINE_INFO=$(system_profiler SPHardwareDataType)
+CHIP_FULL=$(echo "$MACHINE_INFO" | grep "Chip" | cut -d: -f2 | xargs)
+if [[ $CHIP_FULL =~ "Apple" ]]; then
+    CHIP_MODEL=$(echo "$CHIP_FULL" | sed 's/^Apple //' | tr -d ' ' | tr '[:lower:]' '[:upper:]')
+    GPU_CORES=$(ioreg -l | grep "gpu-core-count" | awk -F'= ' '{print $2}')
+    if [ -z "$GPU_CORES" ]; then
+        GPU_CORES="N/A"
+    fi
+else
+    CHIP_MODEL="Intel"
+    GPU_CORES="N/A"
+fi
+
+MEMORY=$(($(sysctl -n hw.memsize) / 1024 / 1024 / 1024))
+
+# Set up GitHub Runner
+RUNNER_DIR="$HOME/actions-runner"
+
+# Check if runner is already configured
+if [ -f "$RUNNER_DIR/.runner" ]; then
+  log "Runner already configured. Stopping existing service..."
+  sudo launchctl unload /Library/LaunchDaemons/com.github.runner.plist 2>/dev/null || true
+fi
+
+# Create runner directory if it doesn't exist
+mkdir -p "$RUNNER_DIR"
+cd "$RUNNER_DIR"
+
+CUSTOM_LABELS="self-hosted,macos,arm64,${CHIP_MODEL}_GPU${GPU_CORES}_${MEMORY}GB"
+
+# Only download and extract if not already present or if forced
+if [ ! -f "$RUNNER_DIR/run.sh" ] || [ "${FORCE_SETUP:-false}" = "true" ]; then
+  log "Downloading GitHub Actions runner..."
+  RUNNER_VERSION=$(curl -s https://api.github.com/repos/actions/runner/releases/latest | grep '"tag_name":' | cut -d'"' -f4)
+  curl -o actions-runner.tar.gz -L "https://github.com/actions/runner/releases/download/${RUNNER_VERSION}/actions-runner-osx-arm64-${RUNNER_VERSION#v}.tar.gz"
+  tar xzf actions-runner.tar.gz
+  rm actions-runner.tar.gz
+else
+  log "Runner already downloaded, skipping download step"
+fi
+
+log "Configuring runner with labels: $CUSTOM_LABELS"
+./config.sh --unattended \
+    --url "https://github.com/${REPO}" \
+    --token "${RUNNER_TOKEN}" \
+    --name "${MACHINE_NAME}" \
+    --labels "${CUSTOM_LABELS}" \
+    --work "_work"
+
+# Set optimal performance settings
+log "Configuring system for optimal performance..."
+
+# Configure CPU performance
+log "Setting CPU performance controls..."
+# Disable timer coalescing
+sudo sysctl -w kern.timer.coalescing_enabled=0
+sudo sysctl -w kern.timer_coalesce_bg_scale=-5
+sudo sysctl -w kern.timer_resort_threshold_ns=0
+# Set minimum timer intervals
+sudo sysctl -w kern.wq_max_timer_interval_usecs=1000
+sudo sysctl -w kern.timer_coalesce_bg_ns_max=1000
+# Set minimum timer coalescing for all tiers
+sudo sysctl -w kern.timer_coalesce_tier0_scale=-5
+sudo sysctl -w kern.timer_coalesce_tier0_ns_max=1000
+sudo sysctl -w kern.timer_coalesce_tier1_scale=-5
+sudo sysctl -w kern.timer_coalesce_tier1_ns_max=1000
+sudo sysctl -w kern.timer_coalesce_tier2_scale=-5
+sudo sysctl -w kern.timer_coalesce_tier2_ns_max=1000
+sudo sysctl -w kern.timer_coalesce_tier3_scale=-5
+sudo sysctl -w kern.timer_coalesce_tier3_ns_max=1000
+sudo sysctl -w kern.timer_coalesce_tier4_scale=-5
+sudo sysctl -w kern.timer_coalesce_tier4_ns_max=1000
+# Disable QoS restrictions
+sudo sysctl -w net.qos.policy.restricted=0
+sudo sysctl -w net.qos.policy.restrict_avapps=0
+sudo sysctl -w net.qos.policy.wifi_enabled=0
+sudo sysctl -w net.qos.policy.capable_enabled=0
+# Set scheduler parameters
+sudo sysctl -w kern.sched_rt_avoid_cpu0=0
+sudo sysctl -w debug.sched=2
+sudo sysctl -w net.pktsched.netem.sched_output_ival_ms=1
+
+# Clean up any existing runner services
+log "Cleaning up existing runner services..."
+for service in com.github.runner com.github.runner.monitor com.github.runner.cpuaffinity com.github.runner.affinity; do
+    sudo launchctl bootout system/$service 2>/dev/null || true
+    sudo rm -f /Library/LaunchDaemons/$service.plist
+done
+
+# Create a simple runner service configuration
+sudo tee /Library/LaunchDaemons/com.github.runner.plist > /dev/null << EOF
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+    <dict>
+        <key>Label</key>
+        <string>com.github.runner</string>
+        <key>UserName</key>
+        <string>$(whoami)</string>
+        <key>GroupName</key>
+        <string>staff</string>
+        <key>WorkingDirectory</key>
+        <string>$RUNNER_DIR</string>
+        <key>ProgramArguments</key>
+        <array>
+            <string>$RUNNER_DIR/run.sh</string>
+        </array>
+        <key>RunAtLoad</key>
+        <true/>
+        <key>KeepAlive</key>
+        <dict>
+            <key>SuccessfulExit</key>
+            <false/>
+            <key>Crashed</key>
+            <true/>
+        </dict>
+        <key>ProcessType</key>
+        <string>Interactive</string>
+        <key>LowPriorityIO</key>
+        <false/>
+        <key>AbandonProcessGroup</key>
+        <false/>
+        <key>EnableTransactions</key>
+        <true/>
+        <key>ThrottleInterval</key>
+        <integer>0</integer>
+        <key>HardResourceLimits</key>
+        <dict>
+            <key>NumberOfFiles</key>
+            <integer>524288</integer>
+            <key>MemoryLock</key>
+            <integer>-1</integer>
+        </dict>
+        <key>SoftResourceLimits</key>
+        <dict>
+            <key>NumberOfFiles</key>
+            <integer>524288</integer>
+            <key>MemoryLock</key>
+            <integer>-1</integer>
+        </dict>
+        <key>QOSClass</key>
+        <string>User-Interactive</string>
+        <key>StandardOutPath</key>
+        <string>$RUNNER_DIR/_diag/runner.log</string>
+        <key>StandardErrorPath</key>
+        <string>$RUNNER_DIR/_diag/runner.err</string>
+        <key>EnvironmentVariables</key>
+        <dict>
+            <key>PATH</key>
+            <string>/usr/local/bin:/opt/homebrew/bin:/usr/bin:/bin:/usr/sbin:/sbin</string>
+        </dict>
+        <key>Nice</key>
+        <integer>-20</integer>
+    </dict>
+</plist>
+EOF
+
+# Set proper permissions for the LaunchDaemon
+sudo chown root:wheel /Library/LaunchDaemons/com.github.runner.plist
+sudo chmod 644 /Library/LaunchDaemons/com.github.runner.plist
+
+# Remove any existing service
+sudo launchctl bootout system/com.github.runner 2>/dev/null || true
+
+# Load the new service using bootstrap
+sudo launchctl bootstrap system /Library/LaunchDaemons/com.github.runner.plist
+
+# Add Runner.Listener permissions (after runner installation)
+RUNNER_PATH="$RUNNER_DIR/bin/Runner.Listener"
+sudo chmod 755 "$RUNNER_PATH"
+sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$RUNNER_PATH"
+sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$RUNNER_PATH"
+
+# Create connection info file if Tailscale is configured
+if [ -n "$TAILSCALE_AUTH_KEY" ]; then
+    TAILSCALE_IP=$(tailscale ip)
+    cat > "$HOME/remote_access_info.txt" << EOL
+Mac Remote Access Information
+============================
+Computer Name: $MACHINE_NAME
+Username: $USER
+Tailscale IP: $TAILSCALE_IP
+
+SSH Command: ssh $USER@$TAILSCALE_IP
+Screen Sharing: vnc://$TAILSCALE_IP
+EOL
+    chmod 600 "$HOME/remote_access_info.txt"
+fi
+
+log "Verifying runner service status..."
+if sudo launchctl list | grep com.github.runner > /dev/null; then
+    log "GitHub Actions runner service is running successfully!"
+    log "Runner labels: $CUSTOM_LABELS"
+    [ -n "$TAILSCALE_AUTH_KEY" ] && log "Remote access details saved to: $HOME/remote_access_info.txt"
+else
+    log "Error: Failed to start GitHub Actions runner service"
+    exit 1
+fi

+ 95 - 0
.github/optimize_performance.sh

@@ -0,0 +1,95 @@
+#!/bin/bash
+set -e
+
+# Function to log with timestamp
+log() {
+  echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
+}
+
+log "Applying comprehensive performance optimizations..."
+
+# System-wide power management
+log "Configuring power management..."
+sudo pmset -a lessbright 0
+sudo pmset -a disablesleep 1
+sudo pmset -a sleep 0
+sudo pmset -a hibernatemode 0
+sudo pmset -a autopoweroff 0
+sudo pmset -a standby 0
+sudo pmset -a powernap 0
+sudo pmset -a proximitywake 0
+sudo pmset -a tcpkeepalive 1
+sudo pmset -a powermode 2
+sudo pmset -a gpuswitch 2
+sudo pmset -a displaysleep 0
+sudo pmset -a disksleep 0
+
+# Memory and kernel optimizations
+log "Configuring memory and kernel settings..."
+sudo sysctl -w kern.memorystatus_purge_on_warning=0
+sudo sysctl -w kern.memorystatus_purge_on_critical=0
+sudo sysctl -w kern.timer.coalescing_enabled=0
+
+# Metal and GPU optimizations
+log "Configuring Metal and GPU settings..."
+defaults write com.apple.CoreML MPSEnableGPUValidation -bool false
+defaults write com.apple.CoreML MPSEnableMetalValidation -bool false
+defaults write com.apple.CoreML MPSEnableGPUDebug -bool false
+defaults write com.apple.Metal GPUDebug -bool false
+defaults write com.apple.Metal GPUValidation -bool false
+defaults write com.apple.Metal MetalValidation -bool false
+defaults write com.apple.Metal MetalCaptureEnabled -bool false
+defaults write com.apple.Metal MTLValidationBehavior -string "Disabled"
+defaults write com.apple.Metal EnableMTLDebugLayer -bool false
+defaults write com.apple.Metal MTLDebugLevel -int 0
+defaults write com.apple.Metal PreferIntegratedGPU -bool false
+defaults write com.apple.Metal ForceMaximumPerformance -bool true
+defaults write com.apple.Metal MTLPreferredDeviceGPUFrame -bool true
+
+# Create MPS cache directory with proper permissions
+sudo mkdir -p /tmp/mps_cache
+sudo chmod 777 /tmp/mps_cache
+
+# Process and resource limits
+log "Configuring process limits..."
+sudo launchctl limit maxfiles 524288 524288
+ulimit -n 524288 || log "Warning: Could not set file descriptor limit"
+ulimit -c 0
+ulimit -l unlimited || log "Warning: Could not set memory lock limit"
+
+# Export performance-related environment variables
+cat << 'EOF' > /tmp/performance_env.sh
+# Metal optimizations
+export MTL_DEBUG_LAYER=0
+export METAL_DEVICE_WRAPPER_TYPE=1
+export METAL_DEBUG_ERROR_MODE=0
+export METAL_FORCE_PERFORMANCE_MODE=1
+export METAL_DEVICE_PRIORITY=high
+export METAL_MAX_COMMAND_QUEUES=1024
+export METAL_LOAD_LIMIT=0
+export METAL_VALIDATION_ENABLED=0
+export METAL_ENABLE_VALIDATION_LAYER=0
+export OBJC_DEBUG_MISSING_POOLS=NO
+export MPS_CACHEDIR=/tmp/mps_cache
+
+# MLX optimizations
+export MLX_USE_GPU=1
+export MLX_METAL_COMPILE_ASYNC=1
+export MLX_METAL_PREALLOCATE=1
+export MLX_METAL_MEMORY_GUARD=0
+export MLX_METAL_CACHE_KERNELS=1
+export MLX_PLACEMENT_POLICY=metal
+export MLX_METAL_VALIDATION=0
+export MLX_METAL_DEBUG=0
+export MLX_FORCE_P_CORES=1
+export MLX_METAL_MEMORY_BUDGET=0
+export MLX_METAL_PREWARM=1
+
+# Python optimizations
+export PYTHONUNBUFFERED=1
+export PYTHONOPTIMIZE=2
+export PYTHONHASHSEED=0
+export PYTHONDONTWRITEBYTECODE=1
+EOF
+
+log "Performance optimizations completed. Environment variables written to /tmp/performance_env.sh"

+ 206 - 0
.github/workflows/bench_job.yml

@@ -0,0 +1,206 @@
+# This is the reusable workflow file
+name: Distributed Job Runner
+
+on:
+  workflow_call:
+    inputs:
+      config:
+        required: true
+        type: string
+      model:
+        required: true
+        type: string
+      calling_job_name:
+        required: true
+        type: string
+      network_interface:
+        required: true
+        type: string
+jobs:
+  generate-matrix:
+    runs-on: ubuntu-latest
+    outputs:
+      matrix: ${{ steps.set-matrix.outputs.matrix }}
+    steps:
+      - id: set-matrix
+        env:
+          CONFIG: ${{ inputs.config }}
+        run: |
+          MATRIX=$(echo $CONFIG | jq -c '{cpu: [to_entries | .[] | .key as $k | range(.value) | $k]}')
+          echo "matrix=$MATRIX" >> $GITHUB_OUTPUT
+
+  run-distributed-job:
+    needs: generate-matrix
+    strategy:
+      matrix: ${{fromJson(needs.generate-matrix.outputs.matrix)}}
+    runs-on: ['self-hosted', 'macOS', '${{ matrix.cpu }}']
+    env:
+      HARDWARE_CONFIG: ${{ inputs.config }}
+      model: ${{ inputs.model }}
+      # Add performance-related environment variables
+      MTL_DEBUG_LAYER: 0
+      METAL_VALIDATION_ENABLED: 0
+      MLX_METAL_VALIDATION: 0
+      MLX_METAL_DEBUG: 0
+      MLX_FORCE_P_CORES: 1
+      MLX_METAL_PREWARM: 1
+      PYTHONOPTIMIZE: 2
+    steps:
+      - name: Cleanup workspace
+        run: |
+          sudo rm -rf "$GITHUB_WORKSPACE"
+          sudo mkdir -p "$GITHUB_WORKSPACE"
+          sudo chown -R $(whoami):$(id -g) "$GITHUB_WORKSPACE"
+
+      - uses: actions/checkout@v4
+
+      - name: Install dependencies
+        run: |
+          export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
+          python3.12 -m venv .venv || {
+            echo "Failed to find python3.12. Checking installation locations:"
+            ls -l /usr/local/bin/python* /opt/homebrew/bin/python* 2>/dev/null || true
+            exit 1
+          }
+          source .venv/bin/activate
+          pip install --upgrade pip
+          pip install -e .
+          pip install boto3==1.35.76
+
+      - name: Apply Performance Optimizations
+        run: |
+          # Export performance-related environment variables
+          cat << 'EOF' > /tmp/performance_env.sh
+          # MLX and Metal optimizations
+          export MTL_DEBUG_LAYER=0
+          export METAL_VALIDATION_ENABLED=0
+          export MLX_METAL_VALIDATION=0
+          export MLX_METAL_DEBUG=0
+          export MLX_FORCE_P_CORES=1
+          export MLX_METAL_PREWARM=1
+          export PYTHONOPTIMIZE=2
+          EOF
+          
+          # Source the performance environment variables
+          source /tmp/performance_env.sh
+
+          # MLX Memory Settings
+          ./configure_mlx.sh
+          
+          # Verify optimizations
+          echo "Verifying performance settings..."
+          env | grep -E "MLX_|METAL_|MTL_"
+
+      - name: Run exo
+        env:
+          aws_access_key_id: ${{ secrets.S3_EXO_BENCHMARKS_AWS_ACCESS_KEY_ID }}
+          aws_secret_key: ${{ secrets.S3_EXO_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
+        run: |
+          # Source performance environment variables
+          source /tmp/performance_env.sh
+          
+          # Debug information
+          echo "Current commit SHA: $GITHUB_SHA"
+          git rev-parse HEAD
+          git status
+          
+          CALLING_JOB="${{ inputs.calling_job_name }}"
+          UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
+          ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
+          MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
+          
+          source .venv/bin/activate
+          export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
+          
+          echo "=== Before starting exo ==="
+          ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
+          ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep -i python
+          
+          echo "Starting exo daemon..."
+          
+          echo "Power mode settings:"
+          sudo pmset -g
+          
+          # Start exo with explicit process control
+          sudo taskpolicy -d default -g default -a -t 0 -l 0 .venv/bin/exo \
+            --node-id="${MY_NODE_ID}" \
+            --node-id-filter="${ALL_NODE_IDS}" \
+            --interface-type-filter="${{ inputs.network_interface }}" \
+            --disable-tui \
+            --max-generate-tokens 250 \
+            --chatgpt-api-port 52415 > output1.log 2>&1 &
+          PID1=$!
+          
+          echo "Exo process started with PID: $PID1"
+          tail -f output1.log &
+          TAIL1=$!
+
+          # Give process time to start
+          sleep 2
+          
+          # Set additional process priorities
+          sudo renice -n -20 -p $PID1
+          sudo taskpolicy -t 4 -p $PID1
+          
+          echo "=== After starting exo ==="
+          ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
+          ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep $PID1
+          
+          echo "Additional process details:"
+          sudo powermetrics -n 1 -i 1000 --show-process-energy | grep -A 5 $PID1 || true
+
+          trap 'kill $TAIL1' EXIT
+          trap 'kill $PID1' EXIT
+
+          echo "Waiting for all nodes to connect..."
+          for i in {1..20}; do
+            echo "Attempt $i: Checking node count..."
+            nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
+            echo "Current node count: $nodes"
+            if [ "$nodes" -eq "${{ strategy.job-total }}" ]; then
+              echo "All nodes connected successfully!"
+              break
+            fi
+            if [ $i -eq 20 ]; then
+              echo "ERROR: Failed to connect all nodes after 20 attempts. Expected ${{ strategy.job-total }} nodes, but got $nodes"
+              exit 1
+            fi
+            sleep 5
+          done
+
+          if ! kill -0 $PID1 2>/dev/null; then
+              echo "ERROR: Instance (PID $PID1) died unexpectedly. Full log output:"
+              cat output1.log
+              exit 1
+          fi
+
+          if [ "${{ strategy.job-index }}" -eq "0" ]; then
+            sleep 10
+            echo "This is the primary node (index 0). Running benchmark..."
+            GITHUB_JOB=$CALLING_JOB python .github/bench.py
+          else
+            echo "This is a secondary node (index ${{ strategy.job-index }}). Waiting for completion..."
+            sleep 10
+            while true; do
+              echo "Checking if primary node is still running..."
+              nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
+              echo "Current node count: $nodes"
+              if [ "$nodes" -lt "${{ strategy.job-total }}" ]; then
+                echo "Primary node completed, exiting..."
+                break
+              fi
+              sleep 5
+            done
+          fi
+
+      - name: Check Final System State
+        if: always()
+        run: |
+          echo "=== Final System State ==="
+          sudo pmset -g
+          sudo powermetrics -n 1 -i 1000 --show-process-energy || true
+          system_profiler SPDisplaysDataType
+          sysctl iogpu
+          ps -eo pid,ppid,user,%cpu,%mem,nice,state,command | grep -i python
+          env | grep -E "MLX_|METAL_|MTL_"
+          echo "=== End Final System State ==="

+ 71 - 0
.github/workflows/benchmarks.yml

@@ -0,0 +1,71 @@
+name: Build and Test
+
+on:
+  push:
+    branches: [ '*' ]
+    tags: [ '*' ]
+  pull_request:
+    branches: [ '*' ]
+
+jobs:
+  single-m4-pro:
+    strategy:
+      matrix:
+        model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
+    uses: ./.github/workflows/bench_job.yml
+    with:
+      config: '{"M4PRO_GPU16_24GB": 1}'
+      model: ${{ matrix.model }}
+      calling_job_name: 'single-m4-pro'
+      network_interface: 'Ethernet'
+    secrets: inherit
+
+  two-m4-pro-cluster:
+    strategy:
+      matrix:
+        model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
+    uses: ./.github/workflows/bench_job.yml
+    with:
+      config: '{"M4PRO_GPU16_24GB": 2}'
+      model: ${{ matrix.model }}
+      calling_job_name: 'two-m4-pro-cluster'
+      network_interface: 'Ethernet'
+    secrets: inherit
+
+  # two-m4-pro-cluster-thunderbolt:
+  #   strategy:
+  #     matrix:
+  #       model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
+  #   uses: ./.github/workflows/bench_job.yml
+  #   with:
+  #     config: '{"M4PRO_GPU16_24GB": 2}'
+  #     model: ${{ matrix.model }}
+  #     calling_job_name: 'two-m4-pro-cluster-thunderbolt'
+  #     network_interface: 'Thunderbolt'
+  #   secrets: inherit
+
+  three-m4-pro-cluster:
+    strategy:
+      matrix:
+        model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', 'llama-3.3-70b']
+      fail-fast: false
+    uses: ./.github/workflows/bench_job.yml
+    with:
+      config: '{"M4PRO_GPU16_24GB": 3}'
+      model: ${{ matrix.model }}
+      calling_job_name: 'three-m4-pro-cluster'
+      network_interface: 'Ethernet'
+    secrets: inherit
+
+  # test-m3-single-node:
+  #   strategy:
+  #     matrix:
+  #       model: ['llama-3.2-1b']
+  #     fail-fast: false
+  #   uses: ./.github/workflows/bench_job.yml
+  #   with:
+  #     config: '{"M3MAX_GPU40_128GB": 1}'
+  #     model: ${{ matrix.model }}
+  #     calling_job_name: 'test-m3-cluster'
+  #     network_interface: 'Ethernet'
+  #   secrets: inherit

+ 32 - 7
configure_mlx.sh

@@ -3,16 +3,41 @@
 # Get the total memory in MB
 TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
 
-# Set WIRED_LIMIT_MB to 80%
-WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
-# Set  WIRED_LWM_MB to 70%
-WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
+# Calculate 80% and TOTAL_MEM_GB-5GB in MB
+EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
+MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))
+
+# Calculate 70% and TOTAL_MEM_GB-8GB in MB
+SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
+MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))
+
+# Set WIRED_LIMIT_MB to higher value
+if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
+  WIRED_LIMIT_MB=$EIGHTY_PERCENT
+else
+  WIRED_LIMIT_MB=$MINUS_5GB
+fi
+
+# Set WIRED_LWM_MB to higher value
+if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
+  WIRED_LWM_MB=$SEVENTY_PERCENT
+else
+  WIRED_LWM_MB=$MINUS_8GB
+fi
 
 # Display the calculated values
 echo "Total memory: $TOTAL_MEM_MB MB"
 echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
 echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
 
-# Apply the values with sysctl
-sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
-sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
+# Apply the values with sysctl, but check if we're already root
+if [ "$EUID" -eq 0 ]; then
+  sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+  sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
+else
+  # Try without sudo first, fall back to sudo if needed
+  sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
+    sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+  sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
+    sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
+fi

+ 75 - 71
exo/api/chatgpt_api.py

@@ -33,6 +33,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
+from collections import defaultdict
 
 
 class Message:
@@ -199,6 +200,11 @@ class ChatGPTAPI:
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     self.default_model = default_model or "llama-3.2-1b"
+    self.token_queues = defaultdict(asyncio.Queue)
+
+    # Get the callback system and register our handler
+    self.token_callback = node.on_token.register("chatgpt-api-token-handler")
+    self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
     self.system_prompt = system_prompt
 
     cors = aiohttp_cors.setup(self.app)
@@ -223,6 +229,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
     cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
     cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
+    cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
 
     # Add static routes
@@ -348,13 +355,13 @@ class ChatGPTAPI:
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
-    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
       chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
     shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
@@ -365,7 +372,7 @@ class ChatGPTAPI:
       )
 
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+    if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
 
     # Add system prompt if set
     if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
@@ -378,28 +385,13 @@ class ChatGPTAPI:
         self.on_chat_completion_request(request_id, chat_request, prompt)
       except Exception as e:
         if DEBUG >= 2: traceback.print_exc()
-    # request_id = None
-    # match = self.prompts.find_longest_prefix(prompt)
-    # if match and len(prompt) > len(match[1].prompt):
-    #     if DEBUG >= 2:
-    #       print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
-    #     request_id = match[1].request_id
-    #     self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-    #     # remove the matching prefix from the prompt
-    #     prompt = prompt[len(match[1].prompt):]
-    # else:
-    #   request_id = str(uuid.uuid4())
-    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-
-    callback_id = f"chatgpt-api-wait-response-{request_id}"
-    callback = self.node.on_token.register(callback_id)
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
 
     try:
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
-      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
+      if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
 
       if stream:
         response = web.StreamResponse(
@@ -412,62 +404,74 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
-          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
-          new_tokens = tokens[prev_last_tokens_len:]
-          finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
-                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
-          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
-            new_tokens = new_tokens[:-1]
-            if is_finished:
-              finish_reason = "stop"
-          if is_finished and not finish_reason:
-            finish_reason = "length"
-
-          completion = generate_completion(
-            chat_request,
-            tokenizer,
-            prompt,
-            request_id,
-            new_tokens,
-            stream,
-            finish_reason,
-            "chat.completion",
-          )
-          if DEBUG >= 2: print(f"Streaming completion: {completion}")
-          try:
+        try:
+          # Stream tokens while waiting for inference to complete
+          while True:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
+            tokens, is_finished = await asyncio.wait_for(
+              self.token_queues[request_id].get(),
+              timeout=self.response_timeout
+            )
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
+
+            eos_token_id = None
+            if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
+            if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
+
+            finish_reason = None
+            if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
+            if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")
+
+            completion = generate_completion(
+              chat_request,
+              tokenizer,
+              prompt,
+              request_id,
+              tokens,
+              stream,
+              finish_reason,
+              "chat.completion",
+            )
+
             await response.write(f"data: {json.dumps(completion)}\n\n".encode())
-          except Exception as e:
-            if DEBUG >= 2: print(f"Error streaming completion: {e}")
-            if DEBUG >= 2: traceback.print_exc()
 
-        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
+            if is_finished:
+              break
 
-          return _request_id == request_id and is_finished
+          await response.write_eof()
+          return response
 
-        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
-        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
-          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
-          try:
-            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
-          except asyncio.TimeoutError:
-            print("WARNING: Stream task timed out. This should not happen.")
-        await response.write_eof()
-        return response
-      else:
-        _, tokens, _ = await callback.wait(
-          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
-          timeout=self.response_timeout,
-        )
+        except asyncio.TimeoutError:
+          if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
+          return web.json_response({"detail": "Response generation timed out"}, status=408)
 
+        except Exception as e:
+          if DEBUG >= 2: 
+            print(f"[ChatGPTAPI] Error processing prompt: {e}")
+            traceback.print_exc()
+          return web.json_response(
+            {"detail": f"Error processing prompt: {str(e)}"},
+            status=500
+          )
+
+        finally:
+          # Clean up the queue for this request
+          if request_id in self.token_queues:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
+            del self.token_queues[request_id]
+      else:
+        tokens = []
+        while True:
+          _tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
+          tokens.extend(_tokens)
+          if is_finished:
+            break
         finish_reason = "length"
-        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
+        eos_token_id = None
+        if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
+        if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
         if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
         if tokens[-1] == eos_token_id:
-          tokens = tokens[:-1]
           finish_reason = "stop"
 
         return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
@@ -476,9 +480,6 @@ class ChatGPTAPI:
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-    finally:
-      deregistered_callback = self.node.on_token.deregister(callback_id)
-      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
   async def handle_post_image_generations(self, request):
     data = await request.json()
@@ -678,6 +679,9 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
 
+  async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
+    await self.token_queues[request_id].put((tokens, is_finished))
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()

+ 1 - 1
exo/download/hf/hf_helpers.py

@@ -441,7 +441,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
     shard_specific_patterns = set(["*.safetensors"])
-  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
+  if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 
 async def get_file_download_percentage(

+ 2 - 1
exo/download/hf/hf_shard_download.py

@@ -159,13 +159,14 @@ class HFShardDownloader(ShardDownloader):
           print(f"Download calculation for {self.current_repo_id}:")
           print(f"Total bytes: {total_bytes}")
           print(f"Downloaded bytes: {downloaded_bytes}")
+        if DEBUG >= 3:
           for file in relevant_files:
             print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
 
       return status
 
     except Exception as e:
-      if DEBUG >= 2:
+      if DEBUG >= 3:
         print(f"Error getting shard download status: {e}")
         traceback.print_exc()
       return None

+ 35 - 9
exo/helpers.py

@@ -14,6 +14,7 @@ from pathlib import Path
 import tempfile
 import json
 from concurrent.futures import ThreadPoolExecutor
+import traceback
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -230,20 +231,21 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
 
 
 def get_all_ip_addresses_and_interfaces():
-  try:
     ip_addresses = []
     for interface in get_if_list():
-      ip = get_if_addr(interface)
-      # Include all addresses, including loopback
-      # Filter out link-local addresses
-      if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
-        # Remove "\\Device\\NPF_" prefix from interface name
+      try:
+        ip = get_if_addr(interface)
+        if ip.startswith("0.0."): continue
         simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
         ip_addresses.append((ip, simplified_interface))
+      except:
+        if DEBUG >= 1: print(f"Failed to get IP address for interface {interface}")
+        if DEBUG >= 1: traceback.print_exc()
+    if not ip_addresses:
+      if DEBUG >= 1: print("Failed to get any IP addresses. Defaulting to localhost.")
+      return [("localhost", "lo")]
     return list(set(ip_addresses))
-  except:
-    if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
-    return [("localhost", "lo")]
+
 
 
 async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
@@ -329,6 +331,30 @@ def is_frozen():
     or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
     or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
 
+async def get_mac_system_info() -> Tuple[str, str, int]:
+    """Get Mac system information using system_profiler."""
+    try:
+        output = await asyncio.get_running_loop().run_in_executor(
+            subprocess_pool,
+            lambda: subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
+        )
+        
+        model_line = next((line for line in output.split("\n") if "Model Name" in line), None)
+        model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
+        
+        chip_line = next((line for line in output.split("\n") if "Chip" in line), None)
+        chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
+        
+        memory_line = next((line for line in output.split("\n") if "Memory" in line), None)
+        memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
+        memory_units = memory_str.split()
+        memory_value = int(memory_units[0])
+        memory = memory_value * 1024 if memory_units[1] == "GB" else memory_value
+        
+        return model_id, chip_id, memory
+    except Exception as e:
+        if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
+        return "Unknown Model", "Unknown Chip", 0
 
 def get_exo_home() -> Path:
   if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"

+ 7 - 0
exo/inference/mlx/perf_improvements.md

@@ -0,0 +1,7 @@
+# Perf improvements
+
+Target: 460 tok/sec
+- removing sample goes from 369 -> 402
+- performance degrades as we generate more tokens
+- make mlx inference engien synchronous, removing thread pool executor: 402 -> 413
+- remove self.on_opaque_status.trigger_all: 413 -> 418

+ 80 - 68
exo/inference/mlx/sharded_inference_engine.py

@@ -1,155 +1,167 @@
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import top_p_sampling, make_sampler
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .sharded_utils import load_shard, get_image_from_str
-from .losses import loss_fns 
+from .losses import loss_fns
 from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
-
-def sample_logits(
-  logits: mx.array,
-  temp: float = 0.0,
-  top_p: float = 1.0,
-  logit_bias: Optional[Dict[int, float]] = None
-) -> Tuple[mx.array, float]:
-  if logit_bias:
-    indices = mx.array(list(logit_bias.keys()))
-    values = mx.array(list(logit_bias.values()))
-    logits[:, indices] += values
-
-  if temp == 0:
-    token = mx.argmax(logits, axis=-1)
-  else:
-    if top_p > 0 and top_p < 1.0:
-      token = top_p_sampling(logits, top_p, temp)
-    else:
-      token = mx.random.categorical(logits*(1/temp))
-
-  return token
+from concurrent.futures import ThreadPoolExecutor
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
     self.caches = OrderedDict()
+    self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
+    self.sampler = make_sampler(*self.sampler_params)
+    self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
+    self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
+    self.session = {}
+
+  async def _eval_mlx(self, *args):
+    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
       self.caches.move_to_end(request_id)
     else:
-      newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
+      newcache = make_prompt_cache(self.model)
       if len(self.caches) > max_caches:
         self.caches.popitem(last=False)
       self.caches[request_id] = newcache
     return {"cache": self.caches[request_id]}
 
-  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
-    y = mx.array(x)
-    logits = y[:, -1, :]
-    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
-    return out
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    if (temp, top_p, 0.0, 1) != self.sampler_params:
+      self.sampler_params = (temp, top_p, 0.0, 1)
+      self.sampler = make_sampler(*self.sampler_params)
+    logits = mx.array(x)
+    logits = logits[:, -1, :]
+    logprobs = logits - mx.logsumexp(logits, keepdims=True)
+    result = self.sampler(logprobs)
+    await self._eval_mlx(result)
+    return np.asarray(result, dtype=int)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return np.asarray(
+      await asyncio.get_running_loop().run_in_executor(
+        self._tokenizer_thread,
+        self.tokenizer.encode,
+        prompt
+      )
+    )
 
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(
+      self._tokenizer_thread,
+      self.tokenizer.decode,
+      tokens
+    )
 
   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
+    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
 
   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
-    
+    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
+
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
     x = mx.array(input_data)
+
     if self.model.model_type != 'StableDiffusionPipeline':
-      output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
+      output_data = await asyncio.get_running_loop().run_in_executor(
+        self._mlx_thread,
+        lambda: self.model(x, **state, **(inference_state or {}))
+      )
+      inference_state = None
     else:
-      output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
-    output_data = np.array(output_data)
+      result = await asyncio.get_running_loop().run_in_executor(
+        self._mlx_thread,
+        lambda: self.model(x, **state, **(inference_state or {}))
+      )
+      output_data, inference_state = result
+
+    output_data = np.array(output_data, copy=False)
     return output_data, inference_state
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
     await self.save_session('loss', loss_fns[loss])
-    loop = asyncio.get_running_loop()
-    #print(f"evaluate in <- {inputs}")
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
-    #print(f"evaluate out -> {score}")
+
+    score = await asyncio.get_running_loop().run_in_executor(
+      self._mlx_thread,
+      lambda: self.session['loss'](self.model, x, y, l)
+    )
     return score
 
   async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
     await self.ensure_shard(shard)
+
     if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
       await self.save_session('train_layers', trainable_layers)
-      self.model.freeze()
-      self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
+      def freeze_unfreeze():
+        self.model.freeze()
+        self.model.apply_to_modules(
+          lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
+        )
+      await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
+
     if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
       await self.save_session('lossname', loss)
       await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
+
     if 'opt' not in self.session:
       await self.save_session('opt', opt(lr))
     return True
 
   async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
-    loop = asyncio.get_running_loop()
-    nothin = await self.ensure_train(shard, loss, opt, lr)
+    await self.ensure_train(shard, loss, opt, lr)
+
     def train_step(inp, tar, lng):
       lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
       gradlayers = grad['model']['layers']
       self.session['opt'].update(self.model, grad)
-      mx.eval(self.model.parameters(), self.session['opt'].state, lval)
-      return lval, gradlayers
+      return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
 
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
+    score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
+      self._mlx_thread,
+      lambda: train_step(x, y, l)
+    )
+    await self._eval_mlx(*eval_args)
 
-    score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
-    #print(f"{score=}")
-      
-    layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
-    #print(layers[0])
-
-    return score, np.array(layers[0]['input_layernorm'])
+    layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
+    first_layer = np.array(layers[0]['input_layernorm'], copy=False)
+    await self._eval_mlx(first_layer)
+    return score, first_layer
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
-
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
-
     if self.shard != shard:
-
-      def load_shard_wrapper():
-        return asyncio.run(load_shard(model_path, shard))
-
-      model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
+      model_shard, self.tokenizer = await load_shard(model_path, shard)
       self.shard = shard
-      self.model = model_shard 
+      self.model = model_shard
       self.caches = OrderedDict()
       self.session = {}
 
+  async def cleanup(self):
+    self._mlx_thread.shutdown(wait=True)

+ 81 - 0
exo/inference/mlx/test_non_blocking.py

@@ -0,0 +1,81 @@
+import asyncio
+import time
+import numpy as np
+from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.inference.shard import Shard
+from exo.models import build_base_shard
+from collections import deque
+from statistics import mean, median
+
+async def test_non_blocking():
+    # Setup
+    shard_downloader = HFShardDownloader()
+    engine = MLXDynamicShardInferenceEngine(shard_downloader)
+    _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
+    shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
+    await engine.ensure_shard(shard)
+    
+    queue = asyncio.Queue()
+    measurements = deque(maxlen=1000000)
+    running = True
+
+    async def mlx_worker():
+        try:
+            start_time = time.time()
+            count = 0
+            while running and (time.time() - start_time) < 5:  # Hard time limit
+                start = time.perf_counter_ns()
+                await engine.infer_prompt("req1", shard, "test prompt")
+                duration = (time.perf_counter_ns() - start) / 1_000_000  # Convert to ms
+                count += 1
+                print(f"MLX operation {count} took: {duration:.3f}ms")
+        except asyncio.CancelledError:
+            pass
+        finally:
+            print(f"\nTotal MLX operations completed: {count}")
+            print(f"Average rate: {count/5:.1f} ops/second")
+
+    async def latency_producer():
+        try:
+            start_time = time.perf_counter_ns()
+            count = 0
+            while running:
+                await queue.put(time.perf_counter_ns())
+                count += 1
+                await asyncio.sleep(0)  # Yield to event loop without delay
+            duration = (time.perf_counter_ns() - start_time) / 1e9  # Convert to seconds
+            print(f"\nProducer iterations: {count}")
+            print(f"Producer rate: {count/duration:.1f} iterations/second")
+        except asyncio.CancelledError:
+            pass
+
+    async def latency_consumer():
+        try:
+            while running:
+                timestamp = await queue.get()
+                latency = (time.perf_counter_ns() - timestamp) / 1_000_000  # Convert to ms
+                measurements.append(latency)
+                queue.task_done()
+        except asyncio.CancelledError:
+            pass
+
+    tasks = [
+        asyncio.create_task(mlx_worker()),
+        asyncio.create_task(latency_producer()),
+        asyncio.create_task(latency_consumer())
+    ]
+    
+    try:
+        await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
+    except asyncio.TimeoutError:
+        print("\nTest timed out")
+    finally:
+        running = False
+        for task in tasks:
+            task.cancel()
+        await asyncio.gather(*tasks, return_exceptions=True)
+        print(f"\nFinal measurement count: {len(measurements)}")
+
+if __name__ == "__main__":
+    asyncio.run(test_non_blocking())

+ 79 - 29
exo/main.py

@@ -13,7 +13,6 @@ import uuid
 import numpy as np
 from functools import partial
 from tqdm import tqdm
-from tqdm.asyncio import tqdm_asyncio
 from exo.train.dataset import load_dataset, iterate_batches, compose
 from exo.networking.manual.manual_discovery import ManualDiscovery
 from exo.networking.manual.network_topology_config import NetworkTopology
@@ -33,6 +32,46 @@ from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
+import uvloop
+from contextlib import asynccontextmanager
+import concurrent.futures
+import socket
+import resource
+import psutil
+
+# TODO: figure out why this is happening
+os.environ["GRPC_VERBOSITY"] = "error"
+os.environ["TRANSFORMERS_VERBOSITY"] = "error"
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+# Configure uvloop for maximum performance
+def configure_uvloop():
+    # Install uvloop as event loop policy
+    uvloop.install()
+
+    # Create new event loop
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+
+    # Increase file descriptor limits on Unix systems
+    if not psutil.WINDOWS:
+      soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+      try:
+          resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+      except ValueError:
+        try:
+          resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
+        except ValueError:
+          pass
+
+    # Configure thread pool for blocking operations
+    loop.set_default_executor(
+      concurrent.futures.ThreadPoolExecutor(
+        max_workers=min(32, (os.cpu_count() or 1) * 4)
+      )
+    )
+
+    return loop
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -52,7 +91,6 @@ parser.add_argument("--models-seed-dir", type=str, default=None, help="Model see
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
-parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
@@ -69,6 +107,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
+parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
 parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
@@ -101,8 +140,9 @@ if DEBUG >= 0:
   for chatgpt_api_endpoint in chatgpt_api_endpoints:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
-# Convert node-id-filter to list if provided
+# Convert node-id-filter and interface-type-filter to lists if provided
 allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
+allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
 
 if args.discovery_module == "udp":
   discovery = UDPDiscovery(
@@ -112,7 +152,8 @@ if args.discovery_module == "udp":
     args.broadcast_port,
     lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
     discovery_timeout=args.discovery_timeout,
-    allowed_node_ids=allowed_node_ids
+    allowed_node_ids=allowed_node_ids,
+    allowed_interface_types=allowed_interface_types
   )
 elif args.discovery_module == "tailscale":
   discovery = TailscaleDiscovery(
@@ -150,9 +191,16 @@ api = ChatGPTAPI(
   default_model=args.default_model,
   system_prompt=args.system_prompt
 )
-node.on_token.register("update_topology_viz").on_next(
-  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
-)
+buffered_token_output = {}
+def update_topology_viz(req_id, tokens, __):
+  if not topology_viz: return
+  if not inference_engine.shard: return
+  if inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
+
+  if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
+  else: buffered_token_output[req_id] = tokens
+  topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+node.on_token.register("update_topology_viz").on_next(update_topology_viz)
 
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
@@ -169,10 +217,6 @@ def preemptively_start_download(request_id: str, opaque_status: str):
 
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
-if args.prometheus_client_port:
-  from exo.stats.metrics import start_metrics_server
-  start_metrics_server(node, args.prometheus_client_port)
-
 last_broadcast_time = 0
 
 
@@ -204,7 +248,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     print(f"Processing prompt: {prompt}")
     await node.process_prompt(shard, prompt, request_id=request_id)
 
-    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+    tokens = []
+    def on_token(_request_id, _tokens, _is_finished):
+      tokens.extend(_tokens)
+      return _request_id == request_id and _is_finished
+    await callback.wait(on_token, timeout=300)
 
     print("\nGenerated response:")
     print(tokenizer.decode(tokens))
@@ -223,7 +271,7 @@ def clean_path(path):
 async def hold_outstanding(node: Node):
   while node.outstanding_requests:
     await asyncio.sleep(.5)
-  return 
+  return
 
 async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
   losses = []
@@ -234,7 +282,7 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
     tokens.append(np.sum(lengths))
   total_tokens = np.sum(tokens)
   total_loss = np.sum(losses) / total_tokens
-  
+
   return total_loss, total_tokens
 
 async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
@@ -270,7 +318,7 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
       await hold_outstanding(node)
   await hold_outstanding(node)
 
-  
+
 async def main():
   loop = asyncio.get_running_loop()
 
@@ -285,7 +333,7 @@ async def main():
           {"❌ No read access" if not has_read else ""}
           {"❌ No write access" if not has_write else ""}
           """)
-    
+
   if not args.models_seed_dir is None:
     try:
       models_seed_dir = clean_path(args.models_seed_dir)
@@ -330,29 +378,31 @@ async def main():
         print("Error: This train ain't leaving the station without a model")
         return
       await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
-    
+
   else:
     asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
     await asyncio.Event().wait()
-  
+
   if args.wait_for_peers > 0:
     print("Cooldown to allow peers to exit gracefully")
     for i in tqdm(range(50)):
       await asyncio.sleep(.1)
 
+@asynccontextmanager
+async def setup_node(args):
+    # Rest of setup_node implementation...
+    pass
 
 def run():
-  loop = asyncio.new_event_loop()
-  asyncio.set_event_loop(loop)
-  try:
-    loop.run_until_complete(main())
-      
-  except KeyboardInterrupt:
-    print("Received keyboard interrupt. Shutting down...")
-  finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
-    loop.close()
-
+    loop = None
+    try:
+        loop = configure_uvloop()
+        loop.run_until_complete(main())
+    except KeyboardInterrupt:
+        print("\nShutdown requested... exiting")
+    finally:
+        if loop:
+            loop.close()
 
 if __name__ == "__main__":
   run()

+ 24 - 18
exo/networking/grpc/grpc_peer_handle.py

@@ -28,6 +28,19 @@ class GRPCPeerHandle(PeerHandle):
     self._device_capabilities = device_capabilities
     self.channel = None
     self.stub = None
+    self.channel_options = [
+      ("grpc.max_metadata_size", 64 * 1024 * 1024),
+      ("grpc.max_receive_message_length", 256 * 1024 * 1024),
+      ("grpc.max_send_message_length", 256 * 1024 * 1024),
+      ("grpc.max_concurrent_streams", 100),
+      ("grpc.http2.min_time_between_pings_ms", 10000),
+      ("grpc.keepalive_time_ms", 20000),
+      ("grpc.keepalive_timeout_ms", 10000),
+      ("grpc.keepalive_permit_without_calls", 1),
+      ("grpc.http2.max_pings_without_data", 0),
+      ("grpc.tcp_nodelay", 1),
+      ("grpc.optimization_target", "throughput"),
+    ]
 
   def id(self) -> str:
     return self._id
@@ -44,7 +57,9 @@ class GRPCPeerHandle(PeerHandle):
   async def connect(self):
     if self.channel is None:
       self.channel = grpc.aio.insecure_channel(
-        self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
+        self.address,
+        options=self.channel_options,
+        compression=grpc.Compression.Gzip
       )
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
@@ -59,7 +74,13 @@ class GRPCPeerHandle(PeerHandle):
     self.stub = None
 
   async def _ensure_connected(self):
-    if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
+    if not await self.is_connected():
+      try:
+        await asyncio.wait_for(self.connect(), timeout=10.0)
+      except asyncio.TimeoutError:
+        if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
+        await self.disconnect()
+        raise
 
   async def health_check(self) -> bool:
     try:
@@ -88,12 +109,7 @@ class GRPCPeerHandle(PeerHandle):
       request_id=request_id,
       inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
     )
-    response = await self.stub.SendPrompt(request)
-
-    if not response.tensor_data or not response.shape or not response.dtype:
-      return None
-
-    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+    await self.stub.SendPrompt(request)
 
   async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
@@ -154,16 +170,6 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
-    response = await self.stub.GetInferenceResult(request)
-    if response.tensor is None:
-      return None, response.is_finished
-    return (
-      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
-      response.is_finished,
-    )
-
   async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
     response = await self.stub.CollectTopology(request)

+ 11 - 3
exo/networking/grpc/grpc_server.py

@@ -27,11 +27,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
 
   async def start(self) -> None:
     self.server = grpc.aio.server(
-      futures.ThreadPoolExecutor(max_workers=10),
+      futures.ThreadPoolExecutor(max_workers=32),
       options=[
         ("grpc.max_metadata_size", 32*1024*1024),
-        ("grpc.max_send_message_length", 128*1024*1024),
-        ("grpc.max_receive_message_length", 128*1024*1024),
+        ("grpc.max_send_message_length", 256*1024*1024),
+        ("grpc.max_receive_message_length", 256*1024*1024),
+        ("grpc.keepalive_time_ms", 10000),
+        ("grpc.keepalive_timeout_ms", 5000),
+        ("grpc.http2.max_pings_without_data", 0),
+        ("grpc.http2.min_time_between_pings_ms", 10000),
+        ("grpc.http2.min_ping_interval_without_data_ms", 5000),
+        ("grpc.max_concurrent_streams", 100),
+        ("grpc.tcp_nodelay", 1),
+        ("grpc.optimization_target", "throughput"),
       ],
     )
     node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)

+ 0 - 10
exo/networking/grpc/node_service.proto

@@ -6,7 +6,6 @@ service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
   rpc SendExample (ExampleRequest) returns (Loss) {}
-  rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc SendResult (SendResultRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
@@ -47,15 +46,6 @@ message Loss {
   float loss = 1;
   optional Tensor grads = 2;
 }
-  
-message GetInferenceResultRequest {
-  string request_id = 1;
-}
-
-message InferenceResult {
-  optional Tensor tensor = 1;
-  bool is_finished = 2;
-}
 
 message Tensor {
   bytes tensor_data = 1;

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 2 - 2
exo/networking/grpc/node_service_pb2.py


+ 44 - 87
exo/networking/grpc/node_service_pb2_grpc.py

@@ -3,7 +3,7 @@
 import grpc
 import warnings
 
-from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
+from . import node_service_pb2 as node__service__pb2
 
 GRPC_GENERATED_VERSION = '1.68.0'
 GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
 if _version_not_supported:
     raise RuntimeError(
         f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+        + f' but the generated code in node_service_pb2_grpc.py depends on'
         + f' grpcio>={GRPC_GENERATED_VERSION}.'
         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,38 @@ class NodeServiceStub(object):
         """
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
+                request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Loss.FromString,
                 _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
+                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
         self.SendResult = channel.unary_unary(
                 '/node_service.NodeService/SendResult',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendOpaqueStatus = channel.unary_unary(
                 '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.HealthCheck = channel.unary_unary(
                 '/node_service.NodeService/HealthCheck',
-                request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
+                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
                 _registered_method=True)
 
 
@@ -97,12 +92,6 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
     def CollectTopology(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -132,43 +121,38 @@ def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
+                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
+                    request_deserializer=node__service__pb2.ExampleRequest.FromString,
+                    response_serializer=node__service__pb2.Loss.SerializeToString,
             ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
+                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+                    response_serializer=node__service__pb2.Topology.SerializeToString,
             ),
             'SendResult': grpc.unary_unary_rpc_method_handler(
                     servicer.SendResult,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
                     servicer.SendOpaqueStatus,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
+                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'HealthCheck': grpc.unary_unary_rpc_method_handler(
                     servicer.HealthCheck,
-                    request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
+                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
             ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +180,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendPrompt',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -223,8 +207,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendTensor',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
+            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.Tensor.FromString,
             options,
             channel_credentials,
             insecure,
@@ -250,35 +234,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendExample',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
+            node__service__pb2.ExampleRequest.SerializeToString,
+            node__service__pb2.Loss.FromString,
             options,
             channel_credentials,
             insecure,
@@ -304,8 +261,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/CollectTopology',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
+            node__service__pb2.CollectTopologyRequest.SerializeToString,
+            node__service__pb2.Topology.FromString,
             options,
             channel_credentials,
             insecure,
@@ -331,8 +288,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendResult',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+            node__service__pb2.SendResultRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -358,8 +315,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendOpaqueStatus',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
+            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -385,8 +342,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/HealthCheck',
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
-            exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
+            node__service__pb2.HealthCheckRequest.SerializeToString,
+            node__service__pb2.HealthCheckResponse.FromString,
             options,
             channel_credentials,
             insecure,

+ 1 - 2
exo/networking/manual/manual_discovery.py

@@ -63,8 +63,7 @@ class ManualDiscovery(Discovery):
             print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
         except Exception as e:
           if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
-      self.known_peers = new_known_peers
-      await asyncio.sleep(1.0)
+      await asyncio.sleep(5.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
 

+ 0 - 4
exo/networking/peer_handle.py

@@ -51,10 +51,6 @@ class PeerHandle(ABC):
   async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     pass
 
-  @abstractmethod
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    pass
-
   @abstractmethod
   async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
     pass

+ 1 - 1
exo/networking/tailscale/tailscale_discovery.py

@@ -40,7 +40,7 @@ class TailscaleDiscovery(Discovery):
     self.update_task = None
 
   async def start(self):
-    self.device_capabilities = device_capabilities()
+    self.device_capabilities = await device_capabilities()
     self.discovery_task = asyncio.create_task(self.task_discover_peers())
     self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
     self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())

+ 42 - 14
exo/networking/udp/udp_discovery.py

@@ -3,7 +3,7 @@ import json
 import socket
 import time
 import traceback
-from typing import List, Dict, Callable, Tuple, Coroutine
+from typing import List, Dict, Callable, Tuple, Coroutine, Optional
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
@@ -23,15 +23,29 @@ class ListenProtocol(asyncio.DatagramProtocol):
     asyncio.create_task(self.on_message(data, addr))
 
 
+def get_broadcast_address(ip_addr: str) -> str:
+  try:
+    # Split IP into octets and create broadcast address for the subnet
+    ip_parts = ip_addr.split('.')
+    return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255"
+  except:
+    return "255.255.255.255"
+
+
 class BroadcastProtocol(asyncio.DatagramProtocol):
-  def __init__(self, message: str, broadcast_port: int):
+  def __init__(self, message: str, broadcast_port: int, source_ip: str):
     self.message = message
     self.broadcast_port = broadcast_port
+    self.source_ip = source_ip
 
   def connection_made(self, transport):
     sock = transport.get_extra_info("socket")
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-    transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
+    # Try both subnet-specific and global broadcast
+    broadcast_addr = get_broadcast_address(self.source_ip)
+    transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port))
+    if broadcast_addr != "255.255.255.255":
+      transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
 
 
 class UDPDiscovery(Discovery):
@@ -45,7 +59,8 @@ class UDPDiscovery(Discovery):
     broadcast_interval: int = 2.5,
     discovery_timeout: int = 30,
     device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
-    allowed_node_ids: List[str] = None,
+    allowed_node_ids: Optional[List[str]] = None,
+    allowed_interface_types: Optional[List[str]] = None,
   ):
     self.node_id = node_id
     self.node_port = node_port
@@ -56,13 +71,14 @@ class UDPDiscovery(Discovery):
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
     self.allowed_node_ids = allowed_node_ids
+    self.allowed_interface_types = allowed_interface_types
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
 
   async def start(self):
-    self.device_capabilities = device_capabilities()
+    self.device_capabilities = await device_capabilities()
     self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
     self.listen_task = asyncio.create_task(self.task_listen_for_peers())
     self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
@@ -82,11 +98,7 @@ class UDPDiscovery(Discovery):
     return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
-    if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
-
     while True:
-      # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
-      # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
       for addr, interface_name in get_all_ip_addresses_and_interfaces():
         interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
         message = json.dumps({
@@ -94,16 +106,26 @@ class UDPDiscovery(Discovery):
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": interface_priority,
           "interface_name": interface_name,
           "interface_type": interface_type,
         })
-        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
 
         transport = None
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
-          if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
+          sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+          sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+          sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+          try:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+          except AttributeError:
+            pass
+          sock.bind((addr, 0))
+          
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
+            lambda: BroadcastProtocol(message, self.broadcast_port, addr),
+            sock=sock
+          )
         except Exception as e:
           print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
         finally:
@@ -111,7 +133,7 @@ class UDPDiscovery(Discovery):
             try: transport.close()
             except Exception as e:
               if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
-              if DEBUG_DISCOVERY >= 2: traceback.print_exc()
+
       await asyncio.sleep(self.broadcast_interval)
 
   async def on_listen_message(self, data, addr):
@@ -147,6 +169,12 @@ class UDPDiscovery(Discovery):
       peer_prio = message["priority"]
       peer_interface_name = message["interface_name"]
       peer_interface_type = message["interface_type"]
+
+      # Skip if interface type is not in allowed list
+      if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
+        if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
+        return
+
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])
 
       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":

+ 19 - 51
exo/orchestration/node.py

@@ -8,7 +8,7 @@ from typing import List, Dict, Optional, Tuple, Union, Set
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from exo.topology.topology import Topology
-from exo.topology.device_capabilities import device_capabilities
+from exo.topology.device_capabilities import device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
 from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
@@ -37,7 +37,7 @@ class Node:
     self.partitioning_strategy = partitioning_strategy
     self.peers: List[PeerHandle] = {}
     self.topology: Topology = Topology()
-    self.device_capabilities = device_capabilities()
+    self.device_capabilities = UNKNOWN_DEVICE_CAPABILITIES
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
     self.buffered_logits: Dict[str, List[np.ndarray]] = {}
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
@@ -56,6 +56,7 @@ class Node:
     self.outstanding_requests = {}
 
   async def start(self, wait_for_peers: int = 0) -> None:
+    self.device_capabilities = await device_capabilities()
     await self.server.start()
     await self.discovery.start()
     await self.update_peers(wait_for_peers)
@@ -70,25 +71,28 @@ class Node:
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
-      if status_data.get("type", "") == "supported_inference_engines":
+      status_type = status_data.get("type", "")
+      if status_type == "supported_inference_engines":
         node_id = status_data.get("node_id")
         engines = status_data.get("engines", [])
         self.topology_inference_engines_pool.append(engines)
-      if status_data.get("type", "") == "node_status":
+      elif status_type == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
         elif status_data.get("status", "").startswith("end_"):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
+
       download_progress = None
-      if status_data.get("type", "") == "download_progress":
+      if status_type == "download_progress":
         if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
         download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
         self.node_download_progress[status_data.get('node_id')] = download_progress
+
       if self.topology_viz:
         self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
     except Exception as e:
-      if DEBUG >= 1: print(f"Error updating visualization: {e}")
+      if DEBUG >= 1: print(f"Error on_node_status: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
   def get_supported_inference_engines(self):
@@ -107,6 +111,8 @@ class Node:
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
+  token_count = 0
+  first_token_time = 0
   async def process_inference_result(
     self,
     shard,
@@ -124,9 +130,8 @@ class Node:
         self.buffered_token_output[request_id][0].append(token.item())
         is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-        asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
         forward = token.reshape(1, -1)
-        intermediate_result = self.buffered_token_output[request_id][0]
+        intermediate_result = [self.buffered_token_output[request_id][0][-1]]
       else:
         forward = result
     else:
@@ -157,6 +162,7 @@ class Node:
     inference_state: Optional[dict] = {},
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
+    start_time = time.perf_counter_ns()
     asyncio.create_task(
       self.broadcast_opaque_status(
         request_id,
@@ -187,18 +193,17 @@ class Node:
           "prompt": prompt,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
         }),
       )
     )
-    return resp
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
 
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
-
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
       self.outstanding_requests[request_id] = "waiting"
@@ -355,41 +360,11 @@ class Node:
     inference_state: Optional[dict] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "start_process_tensor",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "tensor_size": tensor.size,
-          "tensor_shape": tensor.shape,
-          "request_id": request_id,
-        }),
-      )
-    )
     start_time = time.perf_counter_ns()
     resp = await self._process_tensor(shard, tensor, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "end_process_tensor",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "request_id": request_id,
-          "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
-        }),
-      )
-    )
-    return resp
+    if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
 
   async def _process_tensor(
     self,
@@ -402,7 +377,6 @@ class Node:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       self.outstanding_requests[request_id] = "processing"
       result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
@@ -412,7 +386,6 @@ class Node:
       self.outstanding_requests.pop(request_id)
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
-      return None
   
   async def forward_example(
     self,
@@ -558,18 +531,13 @@ class Node:
       try:
         did_peers_change = await self.update_peers()
         if DEBUG >= 2: print(f"{did_peers_change=}")
+        await self.collect_topology(set())
         if did_peers_change:
-          await self.collect_topology(set())
           await self.select_best_inference_engine()
       except Exception as e:
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
 
-  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-    if request_id not in self.buffered_token_output:
-      return None, False
-    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
-
   async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
     next_topology = Topology()
     next_topology.update_node(self.id, self.device_capabilities)
@@ -614,7 +582,7 @@ class Node:
     return self._on_opaque_status
 
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
   
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:

+ 9 - 0
exo/orchestration/test_node.py

@@ -1,6 +1,7 @@
 import unittest
 from unittest.mock import Mock, AsyncMock
 import numpy as np
+import pytest
 
 from .node import Node
 from exo.networking.peer_handle import PeerHandle
@@ -55,3 +56,11 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
     await self.node.process_tensor(input_tensor, None)
 
     self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
+
+  @pytest.mark.asyncio
+  async def test_node_capabilities():
+    node = Node()
+    await node.initialize()
+    caps = await node.get_device_capabilities()
+    assert caps is not None
+    assert caps.model != ""

+ 166 - 0
exo/orchestration/tracing.py

@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Any
+from opentelemetry import trace, context
+from opentelemetry.trace import Status, StatusCode, SpanContext
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from contextlib import contextmanager
+import time
+from threading import Lock
+
+@dataclass
+class TraceContext:
+  request_id: str
+  sequence_number: int
+  current_span: Optional[trace.Span] = None
+  trace_parent: Optional[str] = None
+  token_group_span: Optional[trace.Span] = None
+  token_count: int = 0
+  token_group_size: int = 10  # Default group size
+  request_span: Optional[trace.Span] = None  # Track the main request span
+
+class Tracer:
+  def __init__(self):
+    self.tracer = trace.get_tracer("exo")
+    self.contexts: Dict[str, TraceContext] = {}
+    self._lock = Lock()
+    self.propagator = TraceContextTextMapPropagator()
+    
+  def get_context(self, request_id: str) -> Optional[TraceContext]:
+    with self._lock:
+      return self.contexts.get(request_id)
+
+  def set_context(self, request_id: str, context: TraceContext):
+    with self._lock:
+      self.contexts[request_id] = context
+
+  def inject_context(self, span: trace.Span) -> str:
+    """Inject current span context into carrier for propagation"""
+    carrier = {}
+    ctx = trace.set_span_in_context(span)
+    self.propagator.inject(carrier, context=ctx)
+    return carrier.get("traceparent", "")
+
+  def extract_context(self, trace_parent: str) -> Optional[context.Context]:
+    """Extract span context from carrier"""
+    if not trace_parent:
+      return None
+    carrier = {"traceparent": trace_parent}
+    return self.propagator.extract(carrier)
+
+  def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
+    """Create a new context with the given trace parent"""
+    parent_ctx = self.extract_context(trace_parent)
+    if parent_ctx:
+      # Create a new request span that links to the parent context
+      request_span = self.tracer.start_span(
+        "request",
+        context=parent_ctx,
+        attributes={
+          "request_id": request_id,
+          "sequence_number": sequence_number
+        }
+      )
+      return TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number,
+        request_span=request_span,
+        current_span=request_span,
+        trace_parent=trace_parent
+      )
+    return TraceContext(request_id=request_id, sequence_number=sequence_number)
+
+  def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
+    """Handle token generation and manage token group spans"""
+    context.token_count += 1
+    
+    # Start a new token group span if needed
+    if not context.token_group_span and context.request_span:
+      group_number = (context.token_count - 1) // context.token_group_size + 1
+      
+      # Create token group span as child of request span
+      parent_ctx = trace.set_span_in_context(context.request_span)
+      context.token_group_span = self.tracer.start_span(
+        f"token_group_{group_number}",
+        context=parent_ctx,
+        attributes={
+          "request_id": context.request_id,
+          "group.number": group_number,
+          "group.start_token": context.token_count,
+          "group.max_tokens": context.token_group_size
+        }
+      )
+    
+    # Add token to current group span
+    if context.token_group_span:
+      relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
+      context.token_group_span.set_attribute(f"token.{relative_pos}", token)
+      context.token_group_span.set_attribute("token.count", relative_pos)
+      
+      # End current group span if we've reached the group size or if generation is finished
+      if context.token_count % context.token_group_size == 0 or is_finished:
+        context.token_group_span.set_attribute("token.final_count", relative_pos)
+        context.token_group_span.end()
+        context.token_group_span = None
+
+  @contextmanager
+  def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
+    """Start a new span with proper parent context"""
+    attributes = {
+      "request_id": context.request_id,
+      "sequence_number": context.sequence_number
+    }
+    if extra_attributes:
+      attributes.update(extra_attributes)
+      
+    # Use request span as parent if available
+    parent_ctx = None
+    if context.request_span:
+      parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.trace_parent:
+      parent_ctx = self.extract_context(context.trace_parent)
+      if parent_ctx and not context.request_span:
+        # Create a new request span that links to the parent context
+        context.request_span = self.tracer.start_span(
+          "request",
+          context=parent_ctx,
+          attributes={
+            "request_id": context.request_id,
+            "sequence_number": context.sequence_number
+          }
+        )
+        parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.current_span:
+      parent_ctx = trace.set_span_in_context(context.current_span)
+    
+    # Create span with parent context if it exists
+    if parent_ctx:
+      span = self.tracer.start_span(
+        name,
+        context=parent_ctx,
+        attributes=attributes
+      )
+    else:
+      span = self.tracer.start_span(
+        name,
+        attributes=attributes
+      )
+    
+    # Update context with current span
+    prev_span = context.current_span
+    context.current_span = span
+    
+    try:
+      start_time = time.perf_counter()
+      yield span
+      duration = time.perf_counter() - start_time
+      span.set_attribute("duration_s", duration)
+      span.set_status(Status(StatusCode.OK))
+    except Exception as e:
+      span.set_status(Status(StatusCode.ERROR, str(e)))
+      raise
+    finally:
+      span.end()
+      context.current_span = prev_span
+
+# Global tracer instance
+tracer = Tracer() 

+ 0 - 0
exo/stats/__init__.py


+ 0 - 27
exo/stats/docker-compose-stats.yml

@@ -1,27 +0,0 @@
-version: '3.8'
-
-services:
-  prometheus:
-    image: prom/prometheus:latest
-    container_name: prometheus
-    volumes:
-      - ./prometheus.yml:/etc/prometheus/prometheus.yml
-    command:
-      - '--config.file=/etc/prometheus/prometheus.yml'
-    ports:
-      - "9090:9090"
-    networks:
-      - monitoring
-
-  grafana:
-    image: grafana/grafana:latest
-    container_name: grafana
-    ports:
-      - "3000:3000"
-    networks:
-      - monitoring
-    depends_on:
-      - prometheus
-
-networks:
-  monitoring:

+ 0 - 29
exo/stats/metrics.py

@@ -1,29 +0,0 @@
-from exo.orchestration import Node
-from prometheus_client import start_http_server, Counter, Histogram
-import json
-
-# Create metrics to track time spent and requests made.
-PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
-PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
-PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
-
-
-def start_metrics_server(node: Node, port: int):
-  start_http_server(port)
-
-  def _on_opaque_status(request_id, opaque_status: str):
-    status_data = json.loads(opaque_status)
-    _type = status_data.get("type", "")
-    node_id = status_data.get("node_id", "")
-    if _type != "node_status":
-      return
-    status = status_data.get("status", "")
-
-    if status == "end_process_prompt":
-      PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
-    elif status == "end_process_tensor":
-      elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
-      PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
-      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)  # Convert ns to seconds
-
-  node.on_opaque_status.register("stats").on_next(_on_opaque_status)

+ 0 - 7
exo/stats/prometheus.yml

@@ -1,7 +0,0 @@
-global:
-  scrape_interval: 15s
-
-scrape_configs:
-  - job_name: 'exo-node'
-    static_configs:
-      - targets: ['host.docker.internal:8005']

+ 88 - 0
exo/tinychat/index.css

@@ -654,4 +654,92 @@ main {
 
 .model-download-button i {
   font-size: 0.9em;
+}
+
+.topology-section {
+  margin-bottom: 30px;
+  padding: 15px;
+  background: rgba(255, 255, 255, 0.05);
+  border-radius: 8px;
+}
+
+.topology-visualization {
+  min-height: 150px;
+  position: relative;
+  margin-top: 10px;
+}
+
+.topology-loading {
+  display: flex;
+  align-items: center;
+  gap: 10px;
+  color: #666;
+  font-size: 0.9em;
+}
+
+.topology-node {
+  padding: 8px;
+  background: rgba(255, 255, 255, 0.05);
+  border-radius: 4px;
+  margin: 4px 0;
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.node-info {
+  display: flex;
+  align-items: center;
+  gap: 6px;
+  font-size: 0.9em;
+}
+
+.topology-node .status {
+  width: 6px;
+  height: 6px;
+  border-radius: 50%;
+  flex-shrink: 0;
+}
+
+.topology-node .status.active {
+  background: #4CAF50;
+}
+
+.topology-node .status.inactive {
+  background: #666;
+}
+
+.node-details {
+  padding-left: 12px;
+  display: flex;
+  flex-direction: column;
+  gap: 2px;
+  font-size: 0.8em;
+  opacity: 0.6;
+}
+
+.node-details span {
+  display: flex;
+  align-items: center;
+}
+
+.peer-connections {
+  margin-top: 8px;
+  padding-left: 12px;
+  display: flex;
+  flex-direction: column;
+  gap: 4px;
+}
+
+.peer-connection {
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  font-size: 0.85em;
+  color: #a0a0a0;
+}
+
+.peer-connection i {
+  font-size: 0.8em;
+  color: #666;
 }

+ 37 - 22
exo/tinychat/index.html

@@ -26,21 +26,36 @@
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
   <div class="sidebar">
+    <!-- Add topology section -->
+    <div class="topology-section">
+      <h2 class="megrim-regular">Network Topology</h2>
+      <div class="topology-visualization"
+           x-init="initTopology()"
+           x-ref="topologyViz">
+        <!-- Loading indicator for topology -->
+        <div class="topology-loading" x-show="!topology">
+          <i class="fas fa-spinner fa-spin"></i>
+          <span>Loading topology...</span>
+        </div>
+        <!-- Topology visualization will be rendered here -->
+      </div>
+    </div>
+
     <h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
-    
+
     <!-- Loading indicator -->
     <div class="loading-container" x-show="Object.keys(models).length === 0">
         <i class="fas fa-spinner fa-spin"></i>
         <span>Loading models...</span>
     </div>
-    
+
     <template x-for="(model, key) in models" :key="key">
-        <div class="model-option" 
+        <div class="model-option"
              :class="{ 'selected': cstate.selectedModel === key }"
              @click="cstate.selectedModel = key">
             <div class="model-header">
                 <div class="model-name" x-text="model.name"></div>
-                <button 
+                <button
                     @click.stop="deleteModel(key, model)"
                     class="model-delete-button"
                     x-show="model.download_percentage > 0">
@@ -56,7 +71,7 @@
                         <template x-if="!model.loading && model.download_percentage != null">
                             <span>
                                 <!-- Check if there's an active download for this model -->
-                                <template x-if="downloadProgress?.some(p => 
+                                <template x-if="downloadProgress?.some(p =>
                                     p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
                                 )">
                                     <i class="fas fa-circle-notch fa-spin"></i>
@@ -65,7 +80,7 @@
                             </span>
                         </template>
                         <template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
-                            <button 
+                            <button
                                 @click.stop="handleDownload(key)"
                                 class="model-download-button">
                                 <i class="fas fa-download"></i>
@@ -75,22 +90,22 @@
                     </div>
                 </div>
                 <template x-if="model.total_size">
-                    <div class="model-size" x-text="model.total_downloaded ? 
-                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` : 
+                    <div class="model-size" x-text="model.total_downloaded ?
+                        `${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
                         formatBytes(model.total_size)">
                     </div>
                 </template>
             </div>
         </div>
     </template>
-  </div> 
+  </div>
     <!-- Error Toast -->
     <div x-show="errorMessage !== null" x-transition.opacity class="toast">
         <div class="toast-header">
             <span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
             <div class="toast-header-buttons">
-                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
-                        class="toast-expand-button" 
+                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
+                        class="toast-expand-button"
                         x-show="errorMessage?.stack">
                     <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
                 </button>
@@ -119,8 +134,8 @@
     " x-show="home === 0" x-transition="">
 <h1 class="title megrim-regular">tinychat</h1>
 <template x-if="histories.length">
-  <button 
-    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
+  <button
+    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
     class="clear-history-button">
     <i class="fas fa-trash"></i> Clear All History
   </button>
@@ -162,14 +177,14 @@
 </template>
 </div>
 </div>
-<button 
+<button
     @click="
         home = 0;
         cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
         time_till_first = 0;
         tokens_per_second = 0;
         total_tokens = 0;
-    " 
+    "
     class="back-button"
     x-show="home === 2">
     <i class="fas fa-arrow-left"></i>
@@ -250,7 +265,7 @@
         <p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
         <p><strong>Status:</strong> <span x-text="progress.status"></span></p>
         <div class="progress-bar-container">
-          <div class="progress-bar" 
+          <div class="progress-bar"
                :class="progress.isComplete ? 'complete' : 'in-progress'"
                :style="`width: ${progress.percentage}%;`">
           </div>
@@ -294,10 +309,10 @@
 <i class="fas fa-times"></i>
 </button>
 </div>
-<textarea 
-    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
+<textarea
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
     :placeholder="
-        generating ? 'Generating...' : 
+        generating ? 'Generating...' :
         (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
         'Say something'
     "
@@ -329,9 +344,9 @@
         });
     "
     x-ref="inputForm"></textarea>
-<button 
-    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))" 
-    @click="await handleSend()" 
+<button
+    :disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
+    @click="await handleSend()"
     class="input-button">
     <i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
 </button>

+ 78 - 10
exo/tinychat/index.js

@@ -5,7 +5,7 @@ document.addEventListener("alpine:init", () => {
       time: null,
       messages: [],
       selectedModel: 'llama-3.2-1b',
-    },    
+    },
 
     // historical state
     histories: JSON.parse(localStorage.getItem("histories")) || [],
@@ -13,7 +13,7 @@ document.addEventListener("alpine:init", () => {
     home: 0,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
-    
+
     // Initialize error message structure
     errorMessage: null,
     errorExpanded: false,
@@ -39,6 +39,9 @@ document.addEventListener("alpine:init", () => {
     // Add models state alongside existing state
     models: {},
 
+    topology: null,
+    topologyInterval: null,
+
     init() {
       // Clean up any pending messages
       localStorage.removeItem("pendingMessage");
@@ -48,7 +51,7 @@ document.addEventListener("alpine:init", () => {
 
       // Start polling for download progress
       this.startDownloadProgressPolling();
-      
+
       // Start model polling with the new pattern
       this.startModelPolling();
     },
@@ -82,14 +85,14 @@ document.addEventListener("alpine:init", () => {
     async populateSelector() {
       return new Promise((resolve, reject) => {
         const evtSource = new EventSource(`${window.location.origin}/modelpool`);
-        
+
         evtSource.onmessage = (event) => {
           if (event.data === "[DONE]") {
             evtSource.close();
             resolve();
             return;
           }
-          
+
           const modelData = JSON.parse(event.data);
           // Update existing model data while preserving other properties
           Object.entries(modelData).forEach(([modelName, data]) => {
@@ -102,7 +105,7 @@ document.addEventListener("alpine:init", () => {
             }
           });
         };
-        
+
         evtSource.onerror = (error) => {
           console.error('EventSource failed:', error);
           evtSource.close();
@@ -509,7 +512,7 @@ document.addEventListener("alpine:init", () => {
         stack: error.stack || ""
       };
       this.errorExpanded = false;
-      
+
       if (this.errorTimeout) {
         clearTimeout(this.errorTimeout);
       }
@@ -524,10 +527,10 @@ document.addEventListener("alpine:init", () => {
 
     async deleteModel(modelName, model) {
       const downloadedSize = model.total_downloaded || 0;
-      const sizeMessage = downloadedSize > 0 ? 
+      const sizeMessage = downloadedSize > 0 ?
         `This will free up ${this.formatBytes(downloadedSize)} of space.` :
         'This will remove any partially downloaded files.';
-      
+
       if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
         return;
       }
@@ -541,7 +544,7 @@ document.addEventListener("alpine:init", () => {
         });
 
         const data = await response.json();
-        
+
         if (!response.ok) {
           throw new Error(data.detail || 'Failed to delete model');
         }
@@ -600,6 +603,71 @@ document.addEventListener("alpine:init", () => {
         console.error('Error starting download:', error);
         this.setError(error);
       }
+    },
+
+    async fetchTopology() {
+      try {
+        const response = await fetch(`${this.endpoint}/topology`);
+        if (!response.ok) throw new Error('Failed to fetch topology');
+        return await response.json();
+      } catch (error) {
+        console.error('Topology fetch error:', error);
+        return null;
+      }
+    },
+
+    initTopology() {
+      // Initial fetch
+      this.updateTopology();
+
+      // Set up periodic updates
+      this.topologyInterval = setInterval(() => this.updateTopology(), 5000);
+
+      // Cleanup on page unload
+      window.addEventListener('beforeunload', () => {
+        if (this.topologyInterval) {
+          clearInterval(this.topologyInterval);
+        }
+      });
+    },
+
+    async updateTopology() {
+      const topologyData = await this.fetchTopology();
+      if (!topologyData) return;
+
+      const vizElement = this.$refs.topologyViz;
+      vizElement.innerHTML = ''; // Clear existing visualization
+
+      // Create nodes from object
+      Object.entries(topologyData.nodes).forEach(([nodeId, node]) => {
+        const nodeElement = document.createElement('div');
+        nodeElement.className = 'topology-node';
+
+        // Get peer connections for this node
+        const peerConnections = topologyData.peer_graph[nodeId] || [];
+        const peerConnectionsHtml = peerConnections.map(peer => `
+          <div class="peer-connection">
+            <i class="fas fa-arrow-right"></i>
+            <span>To ${peer.to_id}: ${peer.description}</span>
+          </div>
+        `).join('');
+
+        nodeElement.innerHTML = `
+          <div class="node-info">
+            <span class="status ${nodeId === topologyData.active_node_id ? 'active' : 'inactive'}"></span>
+            <span>${node.model}</span>
+          </div>
+          <div class="node-details">
+            <span>${node.chip}</span>
+            <span>${(node.memory / 1024).toFixed(1)}GB RAM</span>
+            <span>${node.flops.fp32.toFixed(1)} TF</span>
+          </div>
+          <div class="peer-connections">
+            ${peerConnectionsHtml}
+          </div>
+        `;
+        vizElement.appendChild(nodeElement);
+      });
     }
   }));
 });

+ 16 - 23
exo/topology/device_capabilities.py

@@ -3,6 +3,8 @@ from pydantic import BaseModel
 from exo import DEBUG
 import subprocess
 import psutil
+import asyncio
+from exo.helpers import get_mac_system_info, subprocess_pool
 
 TFLOPS = 1.00
 
@@ -144,13 +146,13 @@ CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items
 CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})
 
 
-def device_capabilities() -> DeviceCapabilities:
+async def device_capabilities() -> DeviceCapabilities:
   if psutil.MACOS:
-    return mac_device_capabilities()
+    return await mac_device_capabilities()
   elif psutil.LINUX:
-    return linux_device_capabilities()
+    return await linux_device_capabilities()
   elif psutil.WINDOWS:
-    return windows_device_capabilities()
+    return await windows_device_capabilities()
   else:
     return DeviceCapabilities(
       model="Unknown Device",
@@ -160,27 +162,18 @@ def device_capabilities() -> DeviceCapabilities:
     )
 
 
-def mac_device_capabilities() -> DeviceCapabilities:
-  # Fetch the model of the Mac using system_profiler
-  model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
-  model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
-  model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
-  chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
-  chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
-  memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
-  memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
-  memory_units = memory_str.split()
-  memory_value = int(memory_units[0])
-  if memory_units[1] == "GB":
-    memory = memory_value*1024
-  else:
-    memory = memory_value
-
-  # Assuming static values for other attributes for demonstration
-  return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
+async def mac_device_capabilities() -> DeviceCapabilities:
+  model_id, chip_id, memory = await get_mac_system_info()
+  
+  return DeviceCapabilities(
+    model=model_id,
+    chip=chip_id,
+    memory=memory,
+    flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))
+  )
 
 
-def linux_device_capabilities() -> DeviceCapabilities:
+async def linux_device_capabilities() -> DeviceCapabilities:
   import psutil
   from tinygrad import Device
 

+ 3 - 1
exo/topology/partitioning_strategy.py

@@ -1,8 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Dict
 from dataclasses import dataclass
 from .topology import Topology
 from exo.inference.shard import Shard
+from exo.topology.device_capabilities import device_capabilities
+import asyncio
 
 
 # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1

+ 41 - 38
exo/topology/test_device_capabilities.py

@@ -1,11 +1,11 @@
-import unittest
+import pytest
 from unittest.mock import patch
-from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
+from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS, device_capabilities
 
 
-class TestMacDeviceCapabilities(unittest.TestCase):
-  @patch("subprocess.check_output")
-  def test_mac_device_capabilities_pro(self, mock_check_output):
+@pytest.mark.asyncio
+@patch("subprocess.check_output")
+async def test_mac_device_capabilities_pro(mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:
@@ -27,20 +27,19 @@ Activation Lock Status: Enabled
 """
 
     # Call the function
-    result = mac_device_capabilities()
+    result = await mac_device_capabilities()
 
     # Check the results
-    self.assertIsInstance(result, DeviceCapabilities)
-    self.assertEqual(result.model, "MacBook Pro")
-    self.assertEqual(result.chip, "Apple M3 Max")
-    self.assertEqual(result.memory, 131072)  # 16 GB in MB
-    self.assertEqual(
-      str(result),
-      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
-    )
-
-  @patch("subprocess.check_output")
-  def test_mac_device_capabilities_air(self, mock_check_output):
+    assert isinstance(result, DeviceCapabilities)
+    assert result.model == "MacBook Pro"
+    assert result.chip == "Apple M3 Max"
+    assert result.memory == 131072  # 128 GB in MB
+    assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
+
+
+@pytest.mark.asyncio
+@patch("subprocess.check_output")
+async def test_mac_device_capabilities_air(mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:
@@ -62,30 +61,34 @@ Activation Lock Status: Disabled
 """
 
     # Call the function
-    result = mac_device_capabilities()
+    result = await mac_device_capabilities()
 
     # Check the results
-    self.assertIsInstance(result, DeviceCapabilities)
-    self.assertEqual(result.model, "MacBook Air")
-    self.assertEqual(result.chip, "Apple M2")
-    self.assertEqual(result.memory, 8192)  # 8 GB in MB
+    assert isinstance(result, DeviceCapabilities)
+    assert result.model == "MacBook Air"
+    assert result.chip == "Apple M2"
+    assert result.memory == 8192  # 8 GB in MB
+
 
-  @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
-  def test_mac_device_capabilities_real(self):
+@pytest.mark.skip(reason="Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
+@pytest.mark.asyncio
+async def test_mac_device_capabilities_real():
     # Call the function without mocking
-    result = mac_device_capabilities()
+    result = await mac_device_capabilities()
 
     # Check the results
-    self.assertIsInstance(result, DeviceCapabilities)
-    self.assertEqual(result.model, "MacBook Pro")
-    self.assertEqual(result.chip, "Apple M3 Max")
-    self.assertEqual(result.memory, 131072)  # 128 GB in MB
-    self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
-    self.assertEqual(
-      str(result),
-      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
-    )
-
-
-if __name__ == "__main__":
-  unittest.main()
+    assert isinstance(result, DeviceCapabilities)
+    assert result.model == "MacBook Pro"
+    assert result.chip == "Apple M3 Max"
+    assert result.memory == 131072  # 128 GB in MB
+    assert result.flops == DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)
+    assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
+
+
+@pytest.mark.asyncio
+async def test_device_capabilities():
+    caps = await device_capabilities()
+    assert caps.model != ""
+    assert caps.chip != ""
+    assert caps.memory > 0
+    assert caps.flops is not None

+ 3 - 3
extra/line_counter.py

@@ -74,9 +74,9 @@ def gen_diff(table_old, table_new):
 
 def create_json_report(table, is_diff=False):
     timestamp = datetime.now(timezone.utc).isoformat()
-    commit_sha = os.environ.get('CIRCLE_SHA1', 'unknown')
-    branch = os.environ.get('CIRCLE_BRANCH', 'unknown')
-    pr_number = os.environ.get('CIRCLE_PR_NUMBER', '')
+    commit_sha = os.environ.get('GITHUB_SHA', 'unknown')
+    branch = os.environ.get('GITHUB_REF_NAME', 'unknown')
+    pr_number = os.environ.get('GITHUB_EVENT_NUMBER', '')
 
     if is_diff:
         files = [{

+ 9 - 4
setup.py

@@ -28,14 +28,19 @@ install_requires = [
   "tqdm==4.66.4",
   "transformers==4.46.3",
   "uuid==1.30",
+  "uvloop==0.21.0",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
 ]
 
 extras_require = {
-  "formatting": ["yapf==0.40.2",], "apple_silicon": [
-    "mlx==0.20.0",
-    "mlx-lm==0.19.3",
-  ], "windows": ["pywin32==308",], "nvidia-gpu": ["nvidia-ml-py==12.560.30",], "amd-gpu": ["pyrsmi==0.2.0"]
+  "formatting": ["yapf==0.40.2",],
+  "apple_silicon": [
+    "mlx==0.21.1",
+    "mlx-lm==0.20.4",
+  ],
+  "windows": ["pywin32==308",],
+  "nvidia-gpu": ["nvidia-ml-py==12.560.30",],
+  "amd-gpu": ["pyrsmi==0.2.0"],
 }
 
 # Check if running on macOS with Apple Silicon

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä