# bench.py
  1. import aiohttp
  2. import asyncio
  3. import time
  4. import json
  5. import os
  6. import boto3
  7. from typing import Dict, Any
  8. from datetime import datetime
  9. import subprocess
  10. def check_gpu_access():
  11. try:
  12. # Check if MLX can see the GPU
  13. import mlx.core as mx
  14. print("MLX device info:", mx.default_device())
  15. # Check Metal device availability
  16. result = subprocess.run(['system_profiler', 'SPDisplaysDataType'], capture_output=True, text=True)
  17. print("GPU Info:", result.stdout)
  18. except Exception as e:
  19. print(f"Failed to check GPU access: {e}")
  20. async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
  21. """
  22. Measures the performance of an API endpoint by sending a prompt and recording metrics.
  23. Args:
  24. api_endpoint (str): The API endpoint URL.
  25. prompt (str): The prompt to send to the API.
  26. Returns:
  27. Dict[str, Any]: A dictionary containing performance metrics or error information.
  28. """
  29. results = {
  30. 'model': model,
  31. 'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
  32. 'branch': os.environ.get('GITHUB_REF_NAME', 'unknown'),
  33. 'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
  34. }
  35. # Get token count
  36. session = aiohttp.ClientSession()
  37. try:
  38. response = await session.post(
  39. "http://localhost:52415/v1/chat/token/encode",
  40. json={
  41. "model": model,
  42. "messages": [{"role": "user", "content": prompt}]
  43. }
  44. )
  45. response.raise_for_status()
  46. token_data = await response.json()
  47. results['prompt_len'] = token_data['num_tokens']
  48. except Exception as e:
  49. await session.close()
  50. raise RuntimeError(f"Failed to get token count: {str(e)}")
  51. # Measure completion performance
  52. try:
  53. start_time = time.time()
  54. response = await session.post(
  55. api_endpoint,
  56. json={
  57. "model": model,
  58. "messages": [{"role": "user", "content": prompt}],
  59. "temperature": 0,
  60. "stream": True
  61. }
  62. )
  63. response.raise_for_status()
  64. first_token_time = None
  65. total_tokens = 0
  66. async for line in response.content.iter_chunks():
  67. line = line[0].decode('utf-8').strip()
  68. if not line.startswith('data: '):
  69. continue
  70. data = json.loads(line[6:]) # Skip 'data: ' prefix
  71. if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
  72. print(f"Received content: {content}", flush=True)
  73. if first_token_time is None:
  74. first_token_time = time.time()
  75. ttft = first_token_time - start_time
  76. results.update({
  77. 'ttft': ttft,
  78. 'prompt_tps': results['prompt_len'] / ttft
  79. })
  80. total_tokens += 1
  81. total_time = time.time() - start_time
  82. results.update({
  83. 'generation_tps': total_tokens / total_time,
  84. 'response_len': total_tokens,
  85. 'total_time': total_time
  86. })
  87. except Exception as e:
  88. raise RuntimeError(f"Performance measurement failed: {str(e)}")
  89. finally:
  90. await session.close()
  91. return results
  92. async def main() -> None:
  93. api_endpoint = "http://localhost:52415/v1/chat/completions"
  94. # Define prompts
  95. prompt_warmup = "what is the capital of France?"
  96. prompt_essay = "write an essay about cats"
  97. model = os.environ.get('model', 'llama-3.2-1b')
  98. # Warmup request
  99. print("\nPerforming warmup request...", flush=True)
  100. try:
  101. warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
  102. print("Warmup completed successfully", flush=True)
  103. except Exception as e:
  104. print(f"Warmup request failed: {e}", flush=True)
  105. # Measure performance for the essay prompt
  106. print("\nMeasuring performance for the essay prompt...", flush=True)
  107. results = await measure_performance(api_endpoint, prompt_essay, model)
  108. try:
  109. s3_client = boto3.client(
  110. 's3',
  111. aws_access_key_id=os.environ.get('aws_access_key_id'),
  112. aws_secret_access_key=os.environ.get('aws_secret_key')
  113. )
  114. job_name = os.environ.get('GITHUB_JOB')
  115. # Create S3 key with timestamp and commit info
  116. now = datetime.utcnow()
  117. timestamp = now.strftime('%H-%M-%S')
  118. commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
  119. s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
  120. # Upload to S3
  121. s3_client.put_object(
  122. Bucket='exo-benchmarks',
  123. Key=s3_key,
  124. Body=json.dumps(results),
  125. ContentType='application/json'
  126. )
  127. print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
  128. except Exception as e:
  129. print(f"Failed to upload metrics to S3: {e}", flush=True)
  130. # Optionally print the metrics for visibility
  131. print("Performance metrics:", flush=True)
  132. print(json.dumps(results, indent=4), flush=True)
if __name__ == "__main__":
    # Print GPU/Metal diagnostics first, then run the async benchmark to completion.
    check_gpu_access()
    asyncio.run(main())