Alex Cheema 5 месяцев назад
Родитель
Сommit
16651a3506
1 измененных файлов с 16 добавлено и 24 удалено
  1. 16 24
      extra/dashboard/dashboard.py

+ 16 - 24
extra/dashboard/dashboard.py

@@ -84,30 +84,29 @@ def update_graphs(n, previous_data):
   if previous_data:
     for config_name, data in config_data.items():
       if config_name in previous_data and data and previous_data[config_name]:
-        current_prompt_tps = data[-1]['prompt_tps']
-        previous_prompt_tps = previous_data[config_name][-1]['prompt_tps']
+        current_generation_tps = data[-1]['generation_tps']
+        previous_generation_tps = previous_data[config_name][-1]['generation_tps']
 
         # Add clear logging for TPS changes
-        if current_prompt_tps != previous_prompt_tps:
+        if current_generation_tps != previous_generation_tps:
           print("\n" + "="*50)
           print(f"Config: {config_name}")
-          print(f"Previous TPS: {previous_prompt_tps}")
-          print(f"Current TPS: {current_prompt_tps}")
-          print(f"Change: {current_prompt_tps - previous_prompt_tps}")
+          print(f"Previous Generation TPS: {previous_generation_tps}")
+          print(f"Current Generation TPS: {current_generation_tps}")
+          print(f"Change: {current_generation_tps - previous_generation_tps}")
 
-        if current_prompt_tps > previous_prompt_tps:
-          print("🔼 TPS INCREASED - Should play success sound")
+        if current_generation_tps > previous_generation_tps:
+          print("🔼 Generation TPS INCREASED - Should play success sound")
           trigger_sound = 'success'
-        elif current_prompt_tps < previous_prompt_tps:
-          print("🔽 TPS DECREASED - Should play failure sound")
+        elif current_generation_tps < previous_generation_tps:
+          print("🔽 Generation TPS DECREASED - Should play failure sound")
           trigger_sound = 'failure'
 
-        if current_prompt_tps != previous_prompt_tps:
+        if current_generation_tps != previous_generation_tps:
             print("="*50 + "\n")
 
   for config_name, data in config_data.items():
     timestamps = [d['timestamp'] for d in data]
-    prompt_tps = [d['prompt_tps'] for d in data]
     generation_tps = [d['generation_tps'] for d in data]
     commits = [d['commit'] for d in data]
     run_ids = [d['run_id'] for d in data]
@@ -118,16 +117,6 @@ def update_graphs(n, previous_data):
                        column_widths=[0.7, 0.3])
 
     # Time series plot (left)
-    fig.add_trace(go.Scatter(
-      x=timestamps,
-      y=prompt_tps,
-      name='Prompt TPS',
-      mode='lines+markers',
-      hovertemplate='Commit: %{text}<br>TPS: %{y}<extra></extra>',
-      text=commits,
-      customdata=run_ids
-    ), row=1, col=1)
-
     fig.add_trace(go.Scatter(
       x=timestamps,
       y=generation_tps,
@@ -135,7 +124,9 @@ def update_graphs(n, previous_data):
       mode='lines+markers',
       hovertemplate='Commit: %{text}<br>TPS: %{y}<extra></extra>',
       text=commits,
-      customdata=run_ids
+      customdata=run_ids,
+      line=dict(color='#2196F3', width=2),
+      marker=dict(color='#2196F3')
     ), row=1, col=1)
 
     # Calculate statistics
@@ -152,7 +143,8 @@ def update_graphs(n, previous_data):
       x=generation_tps,
       name='Generation TPS Distribution',
       nbinsx=10,
-      showlegend=False
+      showlegend=False,
+      marker=dict(color='#2196F3')
     ), row=1, col=2)
 
     # Add statistics as annotations