bench.py

import aiohttp
import asyncio
import json
import os
import time
from typing import Any, Dict

async def measure_performance(api_endpoint: str, prompt: str) -> Dict[str, Any]:
    """
    Measure the performance of an API endpoint by sending a prompt and recording metrics.

    Args:
        api_endpoint (str): The API endpoint URL.
        prompt (str): The prompt to send to the API.

    Returns:
        Dict[str, Any]: A dictionary containing performance metrics or error information.
    """
    results: Dict[str, Any] = {}
    request_payload = {
        "model": "llama-3.2-1b",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "stream": True,
    }
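    # temperature=0 minimizes sampling variance between runs, and stream=True is
    # essential here: time-to-first-token can only be observed on a streamed response.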
    async with aiohttp.ClientSession() as session:
        try:
            start_time = time.time()
            first_token_time = None
            total_tokens = 0
            async with session.post(api_endpoint, json=request_payload) as response:
                if response.status != 200:
                    results["error"] = f"HTTP {response.status}: {response.reason}"
                    return results
                async for raw_line in response.content:
                    line = raw_line.decode("utf-8").strip()
                    if not line or not line.startswith("data: "):
                        continue
                    line_content = line[6:]  # Remove the 'data: ' SSE prefix
                    if line_content == "[DONE]":
                        break
                    try:
                        chunk = json.loads(line_content)
                        choice = chunk.get("choices", [{}])[0]
                        content = choice.get("delta", {}).get("content")
                        if content:
                            if first_token_time is None:
                                first_token_time = time.time()
                                results["time_to_first_token"] = first_token_time - start_time
                            # Each streamed delta counts as one token; this is an
                            # approximation, since a chunk may carry more than one token.
                            total_tokens += 1
                    except json.JSONDecodeError:
                        # Skip malformed chunks rather than aborting the benchmark
                        continue
            end_time = time.time()
            total_time = end_time - start_time
            if total_tokens > 0:
                results.update({
                    "tokens_per_second": total_tokens / total_time,
                    "total_tokens": total_tokens,
                    "total_time": total_time,
                })
            else:
                results["error"] = "No tokens were generated"
        except aiohttp.ClientError as e:
            results["error"] = f"Client error: {e}"
        except Exception as e:
            results["error"] = f"Unexpected error: {e}"
    return results
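
# Note: the parser above assumes the endpoint speaks the OpenAI-compatible
# streaming (server-sent events) format, i.e. one JSON chunk per 'data: ' line
# followed by a terminating sentinel, for example:
#
#   data: {"choices": [{"delta": {"content": "Hello"}}]}
#   data: {"choices": [{"delta": {"content": " world"}}]}
#   data: [DONE]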

async def main() -> None:
    api_endpoint = "http://localhost:52415/v1/chat/completions"

    # Define prompts
    prompt_basic = "this is a ping"
    prompt_essay = "write an essay about cats"

    # Measure performance for the basic prompt
    print("Measuring performance for the basic prompt...")
    results_basic = await measure_performance(api_endpoint, prompt_basic)
    print("Basic prompt performance metrics:")
    print(json.dumps(results_basic, indent=4))
    # Measure performance for the longer essay prompt; the short "ping" prompt
    # above doubles as a warm-up for this measurement
    print("\nMeasuring performance for the essay prompt...")
    results = await measure_performance(api_endpoint, prompt_essay)

    # Save metrics from the essay prompt
    metrics_file = os.path.join("artifacts", "benchmark.json")
    os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
    try:
        with open(metrics_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        print(f"Performance metrics saved to {metrics_file}")
    except IOError as e:
        print(f"Failed to save metrics: {e}")

    # Print the metrics for visibility
    print("Performance metrics:")
    print(json.dumps(results, indent=4))


if __name__ == "__main__":
    asyncio.run(main())
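
# Usage sketch: run against a local OpenAI-compatible chat-completions server
# that serves llama-3.2-1b (the port used here, 52415, suggests exo's default
# API port, but any compatible server will do):
#
#   python bench.py
#
# Metrics for the essay prompt are written to artifacts/benchmark.json.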