Browse Source

simplify bench

Alex Cheema 7 months ago
parent
commit
fb44eb086c
1 changed files with 69 additions and 81 deletions
  1. 69 81
      .github/bench.py

+ 69 - 81
.github/bench.py

@@ -19,88 +19,76 @@ async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
     Returns:
         Dict[str, Any]: A dictionary containing performance metrics or error information.
     """
-    model = os.environ.get('model')
-    results: Dict[str, Any] = {'model': model, 'run_id': os.environ.get('GITHUB_RUN_ID')}
-    results['configuration'] = json.loads(os.environ.get('HARDWARE_CONFIG'))
-
-    # Get prompt length in tokens
-    async with aiohttp.ClientSession() as session:
-        try:
-            request_payload = {
+    model = os.environ.get('model', 'llama-3.2-1b')
+
+    results = {
+        'model': model,
+        'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
+        'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
+    }
+
+    # Get token count
+    session = aiohttp.ClientSession()
+    try:
+        response = await session.post(
+            "http://localhost:52415/v1/chat/token/encode",
+            json={
                 "model": model,
                 "messages": [{"role": "user", "content": prompt}]
             }
-            async with session.post(
-                "http://localhost:52415/v1/chat/token/encode",
-                json=request_payload
-            ) as response:
-                token_data = await response.json()
-                prompt_tokens = token_data.get('num_tokens', 0)
-                print(f"Prompt length: {prompt_tokens} tokens", flush=True)
-        except Exception as e:
-            print(f"Failed to get prompt length: {e}", flush=True)
-            prompt_tokens = 0
-    results['prompt_len'] = prompt_tokens
-
-    request_payload = {
-        "model": model,
-        "messages": [{"role": "user", "content": prompt}],
-        "temperature": 0,
-        "stream": True
-    }
+        )
+        response.raise_for_status()
+        token_data = await response.json()
+        results['prompt_len'] = token_data['num_tokens']
+    except Exception as e:
+        await session.close()
+        raise RuntimeError(f"Failed to get token count: {str(e)}")
 
-    async with aiohttp.ClientSession() as session:
-        try:
-            start_time = time.time()
-            first_token_time = None
-            total_tokens = 0
-
-            async with session.post(api_endpoint, json=request_payload) as response:
-                if response.status != 200:
-                    results["error"] = f"HTTP {response.status}: {response.reason}"
-                    return results
-
-                async for raw_line in response.content:
-                    line = raw_line.decode('utf-8').strip()
-                    if not line or not line.startswith('data: '):
-                        continue
-
-                    line_content = line[6:]  # Remove 'data: ' prefix
-                    if line_content == '[DONE]':
-                        break
-
-                    try:
-                        chunk = json.loads(line_content)
-                        choice = chunk.get('choices', [{}])[0]
-                        content = choice.get('delta', {}).get('content')
-
-                        if content:
-                            if first_token_time is None:
-                                first_token_time = time.time()
-                                results['ttft'] = first_token_time - start_time
-                                results['prompt_tps'] = prompt_tokens/results['ttft']
-
-                            total_tokens += 1
-                    except json.JSONDecodeError:
-                        # Log or handle malformed JSON if necessary
-                        continue
-
-            end_time = time.time()
-            total_time = end_time - start_time
-
-            if total_tokens > 0:
-                results.update({
-                    "generation_tps": total_tokens / total_time,
-                    "response_len": total_tokens,
-                    "total_time": total_time
-                })
-            else:
-                results["error"] = "No tokens were generated"
-
-        except aiohttp.ClientError as e:
-            results["error"] = f"Client error: {e}"
-        except Exception as e:
-            results["error"] = f"Unexpected error: {e}"
+    # Measure completion performance
+    try:
+        start_time = time.time()
+        response = await session.post(
+            api_endpoint,
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": 0,
+                "stream": True
+            }
+        )
+        response.raise_for_status()
+
+        first_token_time = None
+        total_tokens = 0
+
+        async for line in response.content.iter_chunks():
+            line = line[0].decode('utf-8').strip()
+            if not line.startswith('data: '):
+                continue
+
+            data = json.loads(line[6:])  # Skip 'data: ' prefix
+            if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
+                print(f"Received content: {content}", flush=True)
+                if first_token_time is None:
+                    first_token_time = time.time()
+                    ttft = first_token_time - start_time
+                    results.update({
+                        'ttft': ttft,
+                        'prompt_tps': results['prompt_len'] / ttft
+                    })
+                total_tokens += 1
+
+        total_time = time.time() - start_time
+        results.update({
+            'generation_tps': total_tokens / total_time,
+            'response_len': total_tokens,
+            'total_time': total_time
+        })
+
+    except Exception as e:
+        raise RuntimeError(f"Performance measurement failed: {str(e)}")
+    finally:
+        await session.close()
 
     return results
 
@@ -122,13 +110,13 @@ async def main() -> None:
             aws_secret_access_key=os.environ.get('aws_secret_key')
         )
         job_name = os.environ.get('GITHUB_JOB')
-        
+
         # Create S3 key with timestamp and commit info
         now = datetime.utcnow()
         timestamp = now.strftime('%H-%M-%S')
         commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
         s3_key = f"{job_name}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
-        
+
         # Upload to S3
         s3_client.put_object(
             Bucket='exo-benchmarks',
@@ -146,4 +134,4 @@ async def main() -> None:
 
 
 if __name__ == "__main__":
-    asyncio.run(main()) 
+    asyncio.run(main())