Browse Source

model matrix

Glen 7 months ago
parent
commit
c8f93721c5
3 changed files with 12 additions and 7 deletions
  1. 5 5
      .github/bench.py
  2. 1 1
      .github/workflows/bench_job.yml
  3. 6 1
      .github/workflows/benchmarks.yml

+ 5 - 5
.github/bench.py

@@ -8,7 +8,7 @@ from typing import Dict, Any
 from datetime import datetime
 from datetime import datetime
 
 
 
 
-async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
+async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
     """
     """
     Measures the performance of an API endpoint by sending a prompt and recording metrics.
     Measures the performance of an API endpoint by sending a prompt and recording metrics.
 
 
@@ -19,7 +19,6 @@ async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
     Returns:
     Returns:
         Dict[str, Any]: A dictionary containing performance metrics or error information.
         Dict[str, Any]: A dictionary containing performance metrics or error information.
     """
     """
-    model = os.environ.get('model', 'llama-3.2-1b')
 
 
     results = {
     results = {
         'model': model,
         'model': model,
@@ -100,17 +99,18 @@ async def main() -> None:
     prompt_warmup = "what is the capital of France?"
     prompt_warmup = "what is the capital of France?"
     prompt_essay = "write an essay about cats"
     prompt_essay = "write an essay about cats"
 
 
+    model = os.environ.get('model', 'llama-3.2-1b')
     # Warmup request
     # Warmup request
     print("\nPerforming warmup request...", flush=True)
     print("\nPerforming warmup request...", flush=True)
     try:
     try:
-        warmup_results = await measure_performance(api_endpoint, prompt_warmup)
+        warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
         print("Warmup completed successfully", flush=True)
         print("Warmup completed successfully", flush=True)
     except Exception as e:
     except Exception as e:
         print(f"Warmup request failed: {e}", flush=True)
         print(f"Warmup request failed: {e}", flush=True)
 
 
     # Measure performance for the essay prompt
     # Measure performance for the essay prompt
     print("\nMeasuring performance for the essay prompt...", flush=True)
     print("\nMeasuring performance for the essay prompt...", flush=True)
-    results = await measure_performance(api_endpoint, prompt_essay)
+    results = await measure_performance(api_endpoint, prompt_essay, model)
 
 
     try:
     try:
         s3_client = boto3.client(
         s3_client = boto3.client(
@@ -124,7 +124,7 @@ async def main() -> None:
         now = datetime.utcnow()
         now = datetime.utcnow()
         timestamp = now.strftime('%H-%M-%S')
         timestamp = now.strftime('%H-%M-%S')
         commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
         commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
-        s3_key = f"{job_name}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
+        s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
 
 
         # Upload to S3
         # Upload to S3
         s3_client.put_object(
         s3_client.put_object(

+ 1 - 1
.github/workflows/bench_job.yml

@@ -62,7 +62,7 @@ jobs:
           ps aux | grep exo || true
           ps aux | grep exo || true
 
 
           CALLING_JOB="${{ inputs.calling_job_name }}"
           CALLING_JOB="${{ inputs.calling_job_name }}"
-          UNIQUE_JOB_ID="${CALLING_JOB}_${GITHUB_RUN_ID}"
+          UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
           ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
           ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
           MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
           MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
           source env/bin/activate
           source env/bin/activate

+ 6 - 1
.github/workflows/benchmarks.yml

@@ -9,9 +9,14 @@ on:
 
 
 jobs:
 jobs:
   test-m4-cluster:
   test-m4-cluster:
+    strategy:
+      matrix:
+        model: ['llama-3.2-1b', 'llama-3.2-3b']
+      # Optional: add fail-fast: false if you want all matrix jobs to continue even if one fails
+      fail-fast: false
     uses: ./.github/workflows/bench_job.yml
     uses: ./.github/workflows/bench_job.yml
     with:
     with:
       config: '{"M4PRO_GPU16_24GB": 2}'
       config: '{"M4PRO_GPU16_24GB": 2}'
-      model: 'llama-3.2-1b'
+      model: ${{ matrix.model }}
       calling_job_name: 'test-m4-cluster'
       calling_job_name: 'test-m4-cluster'
     secrets: inherit
     secrets: inherit