# dashboard.py

import dash
from dash import html, dcc
import plotly.graph_objs as go
from dash.dependencies import Input, Output, State
import aioboto3
import asyncio
import json
from collections import defaultdict
import os
import numpy as np
from plotly.subplots import make_subplots
import plotly.express as px
import aiohttp

# Use an aioboto3 session instead of a plain boto3 client so S3 calls can be awaited
session = aioboto3.Session()

BUCKET_NAME = 'exo-benchmarks'
DISCORD_WEBHOOK_URL = os.getenv('DISCORD_WEBHOOK_URL')
CURSOR_KEY = 'last_processed_timestamp.txt'
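
# Expected object layout in the bucket (inferred from the key parsing in
# load_data_from_s3 below; illustrative, not authoritative):
#   <config>/<model>/<result>.json  ->  results grouped under "<config>/<model>"
# Single-segment keys such as the cursor file above are skipped by that parsing.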


def load_mock_data():
    current_dir = os.path.dirname(os.path.abspath(__file__))
    mock_data_path = os.path.join(current_dir, 'mock_data.json')
    with open(mock_data_path, 'r') as f:
        return json.load(f)
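
# mock_data.json is assumed to mirror the structure that load_data_from_s3()
# returns (a mapping of "<config>/<model>" to a list of benchmark records),
# since it is used as a drop-in replacement below.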


async def load_data_from_s3():
    # For testing, use mock data if the environment variable is set
    if os.getenv('USE_MOCK_DATA'):
        return load_mock_data()

    config_data = defaultdict(list)

    async with session.client('s3') as s3:
        paginator = s3.get_paginator('list_objects_v2')
        objects_to_fetch = []

        # First, collect all object keys
        async for page in paginator.paginate(Bucket=BUCKET_NAME):
            for obj in page.get('Contents', []):
                key = obj['Key']
                key_parts = key.split('/')
                if len(key_parts) < 2:
                    continue
                objects_to_fetch.append((key, obj['LastModified'], f"{key_parts[0]}/{key_parts[1]}"))

        # Then fetch all objects in parallel
        async def fetch_object(key, last_modified, config_name):
            response = await s3.get_object(Bucket=BUCKET_NAME, Key=key)
            body = await response['Body'].read()
            data = json.loads(body.decode('utf-8'))
            print(f"Processing object: {key}: {data}")
            return {
                'config_name': config_name,
                'data': {
                    'timestamp': data.get('timestamp', last_modified.strftime('%Y-%m-%dT%H:%M:%S')),
                    'prompt_tps': data.get('prompt_tps', 0),
                    'generation_tps': data.get('generation_tps', 0),
                    'commit': data.get('commit', ''),
                    'run_id': data.get('run_id', ''),
                    'model': data.get('model', ''),
                    'branch': data.get('branch', ''),
                    'configuration': data.get('configuration', {}),
                    'prompt_len': data.get('prompt_len', 0),
                    'ttft': data.get('ttft', 0),
                    'response_len': data.get('response_len', 0),
                    'total_time': data.get('total_time', 0)
                }
            }

        # Create tasks for all objects
        tasks = [fetch_object(key, last_modified, config_name)
                 for key, last_modified, config_name in objects_to_fetch]
        results = await asyncio.gather(*tasks)

        # Organize results into config_data
        for result in results:
            config_data[result['config_name']].append(result['data'])

    # Sort data by timestamp for each config
    for config in config_data:
        config_data[config].sort(key=lambda x: x['timestamp'])

    return config_data
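
# Illustrative shape of the returned mapping (values are made up; the keys
# match the fields populated in fetch_object above):
#   {
#       "m4-pro-2/llama-3.2-3b": [
#           {"timestamp": "2025-01-01T12:00:00", "prompt_tps": 310.0,
#            "generation_tps": 42.5, "commit": "abc1234", "run_id": "123456789",
#            "branch": "main", "ttft": 0.18, ...},
#           ...
#       ]
#   }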


async def get_best_benchmarks():
    config_data = await load_data_from_s3()
    best_results = {}

    for config_name, data in config_data.items():
        if not data:
            continue
        # Split config_name into config and model
        config, model = config_name.split('/')
        # Find the entry with the highest generation_tps
        best_result = max(data, key=lambda x: x['generation_tps'])
        # Copy all data from the best run, then add config/model info
        result = dict(best_result)
        result.update({
            'config': config,
            'model': model,
        })
        best_results[config_name] = result

    return best_results
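
# Each value holds the full record of that configuration's highest-
# generation_tps run plus the split-out 'config' and 'model' fields; this is
# what generate_best() below uploads as best.json.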


async def get_previous_benchmark(config_data, config_name, current_timestamp):
    """Get the most recent benchmark before current_timestamp for a given configuration."""
    benchmarks = config_data.get(config_name, [])
    # Walk the benchmarks in timestamp order and keep the last one that
    # precedes current_timestamp
    previous = None
    for b in sorted(benchmarks, key=lambda x: x['timestamp']):
        if b['timestamp'] < current_timestamp:
            previous = b
        else:
            break
    return previous


async def format_metric_comparison(current, previous, metric, format_str=".2f", lower_is_better=False):
    """Format a metric with a trend indicator relative to the previous run."""
    current_val = current.get(metric, 0)
    if not previous:
        return f"**{current_val:{format_str}}**"

    prev_val = previous.get(metric, 0)
    diff = current_val - prev_val

    # Invert the comparison if lower values are better, so that a positive
    # diff always means "improved"
    if lower_is_better:
        diff = -diff

    if diff > 0:
        return f"**{current_val:{format_str}}** 🟢↑ ({'-' if lower_is_better else '+'}{abs(current_val - prev_val):{format_str}})"
    elif diff < 0:
        return f"**{current_val:{format_str}}** 🔴↓ ({'+' if lower_is_better else '-'}{abs(current_val - prev_val):{format_str}})"
    else:
        return f"**{current_val:{format_str}}** ⚪"
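
# Example (illustrative values):
#   await format_metric_comparison({'generation_tps': 45.2},
#                                  {'generation_tps': 43.1}, 'generation_tps')
#   -> "**45.20** 🟢↑ (+2.10)"
# With lower_is_better=True the sign and color flip, so a drop in TTFT
# renders green.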


async def send_discord_notification(benchmark_data, config_data):
    if not DISCORD_WEBHOOK_URL:
        print("Discord webhook URL not configured, skipping notification")
        return

    config_name = f"{benchmark_data['config']}/{benchmark_data['model']}"

    # Use the passed config_data instead of fetching again
    previous_benchmark = await get_previous_benchmark(
        config_data,
        config_name,
        benchmark_data['timestamp']
    )

    # Format metrics with comparisons; TTFT is converted to milliseconds
    gen_tps = await format_metric_comparison(benchmark_data, previous_benchmark, 'generation_tps')
    prompt_tps = await format_metric_comparison(benchmark_data, previous_benchmark, 'prompt_tps')
    ttft = await format_metric_comparison(
        {'ttft': benchmark_data['ttft'] * 1000},
        {'ttft': previous_benchmark['ttft'] * 1000} if previous_benchmark else None,
        'ttft',
        lower_is_better=True
    )
    prompt_len = await format_metric_comparison(benchmark_data, previous_benchmark, 'prompt_len', "d")
    response_len = await format_metric_comparison(benchmark_data, previous_benchmark, 'response_len', "d")

    # Render the topology as a JSON code block
    topology = benchmark_data.get('configuration', {})
    topology_str = "```json\n" + json.dumps(topology, indent=2) + "\n```"

    message = (
        f"🚀 New Benchmark Result for **{config_name}**\n\n"
        f"📊 Performance Metrics:\n"
        f"• Generation TPS: {gen_tps}\n"
        f"• Prompt TPS: {prompt_tps}\n"
        f"• TTFT: {ttft}ms\n"
        f"• Prompt Length: {prompt_len}\n"
        f"• Response Length: {response_len}\n\n"
        f"🔍 Run Details:\n"
        f"• Commit: {benchmark_data['commit'][:7]}\n"
        f"• Branch: {benchmark_data['branch']}\n"
        f"• Run ID: [{benchmark_data['run_id']}](https://github.com/exo-explore/exo/actions/runs/{benchmark_data['run_id']})\n\n"
        f"{topology_str}"
    )

    # Use a dedicated aiohttp session; do not shadow the module-level aioboto3 session
    async with aiohttp.ClientSession() as http_session:
        await http_session.post(DISCORD_WEBHOOK_URL, json={'content': message})
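
# Note: Discord caps message content at 2000 characters, so an unusually
# large topology JSON could cause the webhook POST to be rejected; the
# response status is not checked here.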


async def get_cursor():
    try:
        async with session.client('s3') as s3:
            response = await s3.get_object(Bucket=BUCKET_NAME, Key=CURSOR_KEY)
            body = await response['Body'].read()
            return body.decode('utf-8').strip()
    except Exception:
        return "1970-01-01T00:00:00"  # Default to epoch if no cursor exists


async def update_cursor(timestamp):
    async with session.client('s3') as s3:
        await s3.put_object(
            Bucket=BUCKET_NAME,
            Key=CURSOR_KEY,
            Body=timestamp.encode('utf-8')
        )
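
# The cursor is compared as a plain string in generate_best(). That works
# because the ISO-8601 timestamps used here ("%Y-%m-%dT%H:%M:%S") sort
# lexicographically in chronological order, assuming the benchmark payloads
# use the same format as the fallback in load_data_from_s3().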


async def generate_best():
    # Get the last processed timestamp
    last_processed = await get_cursor()
    print(f"Last processed timestamp: {last_processed}")

    async with session.client('s3') as s3:
        # Load all benchmark data once
        config_data = await load_data_from_s3()
        best_benchmarks = await get_best_benchmarks()

        # Check for new benchmarks in all data
        new_latest = last_processed
        for config_name, data_list in config_data.items():
            for benchmark in data_list:
                timestamp = benchmark['timestamp']
                # If this benchmark is newer than our last processed timestamp
                if timestamp > last_processed:
                    print(f"Found new benchmark for {config_name} at {timestamp}")
                    # Add config and model info to the benchmark data
                    config, model = config_name.split('/')
                    benchmark_with_info = dict(benchmark)
                    benchmark_with_info.update({
                        'config': config,
                        'model': model,
                    })
                    # Pass the already loaded config_data to avoid refetching
                    await send_discord_notification(benchmark_with_info, config_data)
                    # Track the newest timestamp we've seen
                    if timestamp > new_latest:
                        new_latest = timestamp

        # Update the cursor if we found any new benchmarks
        if new_latest > last_processed:
            await update_cursor(new_latest)

        # Upload the best benchmarks
        try:
            await s3.put_object(
                Bucket=BUCKET_NAME,
                Key='best.json',
                Body=json.dumps(best_benchmarks, indent=2),
                ContentType='application/json'
            )
            print("Successfully uploaded best.json to S3")
            print(f"Public URL: https://{BUCKET_NAME}.s3.amazonaws.com/best.json")
        except Exception as e:
            print(f"Error uploading to S3: {e}")
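
# Note: get_best_benchmarks() re-fetches everything from S3 even though
# config_data was just loaded; threading config_data through would save a
# full pass over the bucket. The printed public URL also assumes the bucket
# policy allows anonymous reads of best.json.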


app = dash.Dash(__name__)

app.layout = html.Div([
    html.H1('Benchmark Performance Dashboard'),
    html.Button('Test Sound', id='test-sound-button', n_clicks=0),
    html.Div(id='graphs-container'),
    html.Audio(id='success-sound', src='assets/pokemon_evolve.mp3', preload="auto", style={'display': 'none'}),
    html.Audio(id='failure-sound', src='assets/gta5_wasted.mp3', preload="auto", style={'display': 'none'}),
    html.Audio(id='startup-sound', src='assets/pokemon_evolve.mp3', preload="auto", style={'display': 'none'}),
    html.Div(id='audio-trigger', style={'display': 'none'}),
    dcc.Store(id='previous-data', storage_type='memory'),
    dcc.Interval(
        id='interval-component',
        interval=15000,  # Update every 15 seconds
        n_intervals=0
    )
])


@app.callback(
    [Output('graphs-container', 'children'),
     Output('previous-data', 'data'),
     Output('audio-trigger', 'children')],
    [Input('interval-component', 'n_intervals')],
    [State('previous-data', 'data')]
)
def update_graphs(n, previous_data):
    # Run the async S3 load synchronously inside the callback
    config_data = asyncio.run(load_data_from_s3())
    graphs = []
    trigger_sound = None

    # Compare against the previous poll and pick a sound to play
    if previous_data:
        for config_name, data in config_data.items():
            if config_name in previous_data and data and previous_data[config_name]:
                current_generation_tps = data[-1]['generation_tps']
                previous_generation_tps = previous_data[config_name][-1]['generation_tps']

                # Log TPS changes clearly
                if current_generation_tps != previous_generation_tps:
                    print("\n" + "=" * 50)
                    print(f"Config: {config_name}")
                    print(f"Previous Generation TPS: {previous_generation_tps}")
                    print(f"Current Generation TPS: {current_generation_tps}")
                    print(f"Change: {current_generation_tps - previous_generation_tps}")
                    if current_generation_tps > previous_generation_tps:
                        print("🔼 Generation TPS INCREASED - Should play success sound")
                        trigger_sound = 'success'
                    else:
                        print("🔽 Generation TPS DECREASED - Should play failure sound")
                        trigger_sound = 'failure'
                    print("=" * 50 + "\n")

    for config_name, data in config_data.items():
        generation_tps = [d['generation_tps'] for d in data]
        # Unique branches for this config
        branches = list(set(d['branch'] for d in data))

        # Subplot with the time series on the left and a histogram on the right
        fig = make_subplots(rows=1, cols=2,
                            subplot_titles=('Performance Over Time', 'Generation TPS Distribution'),
                            column_widths=[0.7, 0.3])

        # Assign a color to each branch
        colors = px.colors.qualitative.Set1[:len(branches)]
        branch_colors = dict(zip(branches, colors))

        # Time series plot (left) - separate line for each branch
        for branch in branches:
            branch_data = [d for d in data if d['branch'] == branch]
            branch_timestamps = [d['timestamp'] for d in branch_data]
            branch_generation_tps = [d['generation_tps'] for d in branch_data]
            branch_commits = [d['commit'] for d in branch_data]
            branch_run_ids = [d['run_id'] for d in branch_data]

            fig.add_trace(go.Scatter(
                x=branch_timestamps,
                y=branch_generation_tps,
                name=f'{branch}',
                mode='lines+markers',
                hovertemplate='Branch: %{text}<br>Commit: %{customdata[0]}<br>TPS: %{y}<extra></extra>',
                text=[branch] * len(branch_timestamps),
                # Carry (commit, run_id) so the click handler can open the run
                customdata=list(zip(branch_commits, branch_run_ids)),
                line=dict(color=branch_colors[branch], width=2),
                marker=dict(color=branch_colors[branch])
            ), row=1, col=1)

        # Histogram plot (right) - stacked histogram by branch
        for branch in branches:
            branch_data = [d for d in data if d['branch'] == branch]
            branch_generation_tps = [d['generation_tps'] for d in branch_data]
            fig.add_trace(go.Histogram(
                x=branch_generation_tps,
                name=f'{branch}',
                nbinsx=10,
                marker=dict(color=branch_colors[branch]),
                opacity=0.75
            ), row=1, col=2)

        # Summary statistics over all branches of this config
        gen_tps_array = np.array(generation_tps)
        stats = {
            'Mean': np.mean(gen_tps_array),
            'Std Dev': np.std(gen_tps_array),
            'Min': np.min(gen_tps_array),
            'Max': np.max(gen_tps_array)
        }

        # Add the statistics as an annotation
        stats_text = '<br>'.join([f'{k}: {v:.2f}' for k, v in stats.items()])
        fig.add_annotation(
            x=0.98,
            y=0.98,
            xref='paper',
            yref='paper',
            text=stats_text,
            showarrow=False,
            font=dict(size=12),
            align='left',
            bgcolor='rgba(255, 255, 255, 0.8)',
            bordercolor='black',
            borderwidth=1
        )

        fig.update_layout(
            title=f'Performance Metrics - {config_name}',
            height=500,
            showlegend=True,
            hovermode='x unified',
            clickmode='event'
        )

        # Axis labels
        fig.update_xaxes(title_text='Timestamp', row=1, col=1)
        fig.update_xaxes(title_text='Generation TPS', row=1, col=2)
        fig.update_yaxes(title_text='Tokens per Second', row=1, col=1)
        fig.update_yaxes(title_text='Count', row=1, col=2)

        graphs.append(html.Div([
            dcc.Graph(
                figure=fig,
                id={'type': 'dynamic-graph', 'index': config_name},
                config={'displayModeBar': True}
            )
        ]))

    return graphs, config_data, trigger_sound
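
# Note: every interval tick re-downloads every object in the bucket through
# asyncio.run(load_data_from_s3()); for large buckets a cache keyed on each
# object's LastModified would cut the per-tick cost considerably.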


@app.callback(
    Output('graphs-container', 'children', allow_duplicate=True),
    Input({'type': 'dynamic-graph', 'index': dash.ALL}, 'clickData'),
    prevent_initial_call=True
)
def handle_click(clickData):
    # If you add any async operations here, wrap them with asyncio.run()
    if clickData and clickData[0] and clickData[0]['points'][0].get('customdata'):
        # customdata is (commit, run_id); open the GitHub Actions run for this point
        run_id = clickData[0]['points'][0]['customdata'][1]
        url = f'https://github.com/exo-explore/exo/actions/runs/{run_id}'
        import webbrowser
        webbrowser.open_new_tab(url)
    return dash.no_update
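
# Caveat: webbrowser runs on the machine hosting the Dash server, so the tab
# opens server-side. That is fine for local use, but remote viewers would
# need a clientside callback or an html.A link instead.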


app.clientside_callback(
    """
    function(trigger, test_clicks) {
        if (!trigger && !test_clicks) return window.dash_clientside.no_update;

        if (test_clicks > 0 && dash_clientside.callback_context.triggered[0].prop_id.includes('test-sound-button')) {
            console.log('Test button clicked');
            const audio = document.getElementById('startup-sound');
            if (audio) {
                audio.currentTime = 0;
                audio.play().catch(e => console.log('Error playing audio:', e));
            }
        } else if (trigger) {
            console.log('Audio trigger received:', trigger);
            if (trigger === 'success') {
                console.log('Playing success sound');
                const audio = document.getElementById('success-sound');
                if (audio) {
                    audio.currentTime = 0;
                    audio.play().catch(e => console.log('Error playing success sound:', e));
                }
            } else if (trigger === 'failure') {
                console.log('Playing failure sound');
                const audio = document.getElementById('failure-sound');
                if (audio) {
                    audio.currentTime = 0;
                    audio.play().catch(e => console.log('Error playing failure sound:', e));
                }
            }
        }
        return window.dash_clientside.no_update;
    }
    """,
    Output('audio-trigger', 'children', allow_duplicate=True),
    [Input('audio-trigger', 'children'),
     Input('test-sound-button', 'n_clicks')],
    prevent_initial_call=True
)


if __name__ == '__main__':
    import sys
    if '--generate' in sys.argv:
        asyncio.run(generate_best())
    else:
        app.run_server(debug=True)
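
# Usage (a sketch; 8050 is the Dash dev server's default port):
#   python dashboard.py                    # serve the dashboard at http://127.0.0.1:8050
#   python dashboard.py --generate         # one-shot: post Discord updates and upload best.json
#   USE_MOCK_DATA=1 python dashboard.py    # develop against the local mock_data.json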