# bench.py — benchmark a local OpenAI-style chat-completions endpoint and upload metrics to S3.
import asyncio
import json
import os
import time
from datetime import datetime, timezone
from typing import Any, Dict

import aiohttp
import boto3
  9. async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
  10. """
  11. Measures the performance of an API endpoint by sending a prompt and recording metrics.
  12. Args:
  13. api_endpoint (str): The API endpoint URL.
  14. prompt (str): The prompt to send to the API.
  15. Returns:
  16. Dict[str, Any]: A dictionary containing performance metrics or error information.
  17. """
  18. results = {
  19. 'model': model,
  20. 'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
  21. 'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
  22. }
  23. # Get token count
  24. session = aiohttp.ClientSession()
  25. try:
  26. response = await session.post(
  27. "http://localhost:52415/v1/chat/token/encode",
  28. json={
  29. "model": model,
  30. "messages": [{"role": "user", "content": prompt}]
  31. }
  32. )
  33. response.raise_for_status()
  34. token_data = await response.json()
  35. results['prompt_len'] = token_data['num_tokens']
  36. except Exception as e:
  37. await session.close()
  38. raise RuntimeError(f"Failed to get token count: {str(e)}")
  39. # Measure completion performance
  40. try:
  41. start_time = time.time()
  42. response = await session.post(
  43. api_endpoint,
  44. json={
  45. "model": model,
  46. "messages": [{"role": "user", "content": prompt}],
  47. "temperature": 0,
  48. "stream": True
  49. }
  50. )
  51. response.raise_for_status()
  52. first_token_time = None
  53. total_tokens = 0
  54. async for line in response.content.iter_chunks():
  55. line = line[0].decode('utf-8').strip()
  56. if not line.startswith('data: '):
  57. continue
  58. data = json.loads(line[6:]) # Skip 'data: ' prefix
  59. if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
  60. print(f"Received content: {content}", flush=True)
  61. if first_token_time is None:
  62. first_token_time = time.time()
  63. ttft = first_token_time - start_time
  64. results.update({
  65. 'ttft': ttft,
  66. 'prompt_tps': results['prompt_len'] / ttft
  67. })
  68. total_tokens += 1
  69. total_time = time.time() - start_time
  70. results.update({
  71. 'generation_tps': total_tokens / total_time,
  72. 'response_len': total_tokens,
  73. 'total_time': total_time
  74. })
  75. except Exception as e:
  76. raise RuntimeError(f"Performance measurement failed: {str(e)}")
  77. finally:
  78. await session.close()
  79. return results
  80. async def main() -> None:
  81. api_endpoint = "http://localhost:52415/v1/chat/completions"
  82. # Define prompts
  83. prompt_warmup = "what is the capital of France?"
  84. prompt_essay = "write an essay about cats"
  85. model = os.environ.get('model', 'llama-3.2-1b')
  86. # Warmup request
  87. print("\nPerforming warmup request...", flush=True)
  88. try:
  89. warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
  90. print("Warmup completed successfully", flush=True)
  91. except Exception as e:
  92. print(f"Warmup request failed: {e}", flush=True)
  93. # Measure performance for the essay prompt
  94. print("\nMeasuring performance for the essay prompt...", flush=True)
  95. results = await measure_performance(api_endpoint, prompt_essay, model)
  96. try:
  97. s3_client = boto3.client(
  98. 's3',
  99. aws_access_key_id=os.environ.get('aws_access_key_id'),
  100. aws_secret_access_key=os.environ.get('aws_secret_key')
  101. )
  102. job_name = os.environ.get('GITHUB_JOB')
  103. # Create S3 key with timestamp and commit info
  104. now = datetime.utcnow()
  105. timestamp = now.strftime('%H-%M-%S')
  106. commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
  107. s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
  108. # Upload to S3
  109. s3_client.put_object(
  110. Bucket='exo-benchmarks',
  111. Key=s3_key,
  112. Body=json.dumps(results),
  113. ContentType='application/json'
  114. )
  115. print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
  116. except Exception as e:
  117. print(f"Failed to upload metrics to S3: {e}", flush=True)
  118. # Optionally print the metrics for visibility
  119. print("Performance metrics:", flush=True)
  120. print(json.dumps(results, indent=4), flush=True)
# Script entry point: run the async benchmark end-to-end.
if __name__ == "__main__":
    asyncio.run(main())