# bench.py — measures time-to-first-token (TTFT) and tokens/sec against a
# streaming OpenAI-style chat-completions endpoint.
import aiohttp
import asyncio
import time
import json
import os
  6. async def measure_performance(api_endpoint: str, prompt: str = "Who are you?"):
  7. async with aiohttp.ClientSession() as session:
  8. request = {
  9. "model": "llama-3.2-3b",
  10. "messages": [{"role": "user", "content": prompt}],
  11. "stream": True
  12. }
  13. start_time = time.time()
  14. first_token_time = None
  15. total_tokens = 0
  16. print(f"Sending request to {api_endpoint}...")
  17. async with session.post(api_endpoint, json=request) as response:
  18. async for line in response.content:
  19. if not line.strip():
  20. continue
  21. line = line.decode('utf-8')
  22. if line.startswith('data: '):
  23. line = line[6:] # Remove 'data: ' prefix
  24. if line == '[DONE]':
  25. break
  26. try:
  27. chunk = json.loads(line)
  28. if chunk.get('choices') and chunk['choices'][0].get('delta', {}).get('content'):
  29. if first_token_time is None:
  30. first_token_time = time.time()
  31. ttft = first_token_time - start_time
  32. print(f"Time to first token: {ttft:.3f}s")
  33. total_tokens += 1
  34. except json.JSONDecodeError:
  35. continue
  36. end_time = time.time()
  37. total_time = end_time - start_time
  38. if total_tokens > 0:
  39. tps = total_tokens / total_time
  40. print(f"Tokens per second: {tps:.1f}")
  41. print(f"Total tokens generated: {total_tokens}")
  42. print(f"Total time: {total_time:.3f}s")
  43. else:
  44. print("No tokens were generated")
  45. if __name__ == "__main__":
  46. API_ENDPOINT = os.getenv("API_ENDPOINT", "http://localhost:52415/v1/chat/completions")
  47. asyncio.run(measure_performance(API_ENDPOINT, prompt="Write an essay about life, the universe, and everything."))