thenatlog 1 год назад
Родитель
Сommit
b8f245c1ff
2 измененных файлов с 102 добавлено и 37 удалено
  1. 102 28
      .circleci/bench.py
  2. 0 9
      .circleci/config.yml

+ 102 - 28
.circleci/bench.py

@@ -3,9 +3,10 @@ import asyncio
 import time
 import time
 import json
 import json
 import os
 import os
+import subprocess
+import signal
 from typing import Dict, Any
 from typing import Dict, Any
 
 
-
 async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
 async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
     """
     """
     Measures the performance of an API endpoint by sending a prompt and recording metrics.
     Measures the performance of an API endpoint by sending a prompt and recording metrics.
@@ -80,37 +81,110 @@ async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
     return results
     return results
 
 
 
 
-async def main() -> None:
-    api_endpoint = "http://localhost:52415/v1/chat/completions"
-
-    # Define prompts
-    prompt_basic = "hello"
-    prompt_essay = "write an essay about cats"
+async def wait_for_exo(api_endpoint: str, timeout: int = 60) -> bool:
+    """
+    Waits until the Exo API is ready to accept connections.
 
 
-    # Measure performance for the basic prompt
-    print("Measuring performance for the basic prompt...")
-    results_basic = await measure_performance(api_endpoint, prompt_basic)
-    print("Basic prompt performance metrics:")
-    print(json.dumps(results_basic, indent=4))
+    Args:
+        api_endpoint (str): The API endpoint URL.
+        timeout (int): Maximum time to wait in seconds.
 
 
-    # Measure performance for the essay prompt, which depends on the first measurement
-    print("\nMeasuring performance for the essay prompt...")
-    results = await measure_performance(api_endpoint, prompt_essay)
+    Returns:
+        bool: True if Exo is ready, False otherwise.
+    """
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(api_endpoint.replace("/v1/chat/completions", "")) as response:
+                    if response.status == 200:
+                        return True
+        except:
+            pass
+        await asyncio.sleep(2)  # Wait before retrying
+    return False
 
 
-    # Save metrics from the "universe and everything" prompt
-    metrics_file = os.path.join("artifacts", "benchmark.json")
-    os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
-    try:
-        with open(metrics_file, "w", encoding="utf-8") as f:
-            json.dump(results, f, indent=4)
-        print(f"Performance metrics saved to {metrics_file}")
-    except IOError as e:
-        print(f"Failed to save metrics: {e}")
 
 
-    # Optionally print the metrics for visibility
-    print("Performance metrics:")
-    print(json.dumps(results, indent=4))
+async def main() -> None:
+    exo_command = [
+        "/opt/homebrew/bin/python3.12",
+        "-m",
+        "venv",
+        "venv"
+    ]
+    # Initialize virtual environment
+    print("Setting up virtual environment...")
+    subprocess.run(exo_command, check=True)
+
+    # Activate virtual environment and install dependencies
+    activate_command = "source venv/bin/activate && pip install -U pip && pip install -e ."
+    print("Installing dependencies...")
+    subprocess.run(activate_command, shell=True, check=True)
+
+    # Start Exo as a subprocess
+    print("Starting Exo...")
+    exo_process = subprocess.Popen(
+        ["venv/bin/exo", "run", "llama-3.2-3b", "--prompt", "hello"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        preexec_fn=os.setsid  # To allow killing the entire process group
+    )
 
 
+    try:
+        # Wait for Exo to be ready
+        api_endpoint = "http://localhost:52415/v1/chat/completions"
+        print("Waiting for Exo to initialize...")
+        is_ready = await wait_for_exo(api_endpoint)
+        if not is_ready:
+            raise RuntimeError("Exo did not initialize within the expected time.")
+
+        # Define prompts
+        prompt_basic = "hello"
+        prompt_essay = "write an essay about cats"
+
+        # Measure performance for the basic prompt
+        print("Measuring performance for the basic prompt...")
+        results_basic = await measure_performance(api_endpoint, prompt_basic)
+        print("Basic prompt performance metrics:")
+        print(json.dumps(results_basic, indent=4))
+
+        # Measure performance for the essay prompt, which depends on the first measurement
+        print("\nMeasuring performance for the essay prompt...")
+        results_essay = await measure_performance(api_endpoint, prompt_essay)
+
+        # Combine results
+        combined_results = {
+            "basic_prompt": results_basic,
+            "essay_prompt": results_essay
+        }
+
+        # Save metrics to artifacts
+        metrics_file = os.path.join("artifacts", "benchmark.json")
+        os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
+        try:
+            with open(metrics_file, "w", encoding="utf-8") as f:
+                json.dump(combined_results, f, indent=4)
+            print(f"Performance metrics saved to {metrics_file}")
+        except IOError as e:
+            print(f"Failed to save metrics: {e}")
+
+        # Optionally print the metrics for visibility
+        print("Performance metrics:")
+        print(json.dumps(combined_results, indent=4))
+
+    except Exception as e:
+        print(f"An error occurred during benchmarking: {e}")
+    finally:
+        # Terminate Exo subprocess
+        print("Shutting down Exo...")
+        try:
+            os.killpg(os.getpgid(exo_process.pid), signal.SIGTERM)
+            exo_process.wait(timeout=30)
+            print("Exo shut down successfully.")
+        except Exception as e:
+            print(f"Failed to terminate Exo gracefully: {e}")
+            exo_process.kill()
+            print("Exo was forcefully terminated.")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    asyncio.run(main()) 
+    asyncio.run(main())

+ 0 - 9
.circleci/config.yml

@@ -18,16 +18,7 @@ jobs:
             source venv/bin/activate
             source venv/bin/activate
             pip install -U pip
             pip install -U pip
             pip install -e .
             pip install -e .
-            # Start exo as a background process and redirect output to a log file
-            exo run llama-3.2-3b --prompt "hello" > exo.log 2>&1 &
-            EXO_PID=$!
-            echo "Started exo with PID $EXO_PID"
-            # Wait for exo to initialize
-            sleep 10
-            # Run the benchmark
             python3 .circleci/bench.py
             python3 .circleci/bench.py
-            # Stop exo after the benchmark completes
-            kill $EXO_PID
       - store_artifacts:
       - store_artifacts:
           path: artifacts/benchmark.json
           path: artifacts/benchmark.json
           destination: benchmark.json
           destination: benchmark.json