
Merge pull request #501 from exo-explore/beautifuldashboard

Beautiful dashboard
Alex Cheema 8 months ago
parent commit 3034408a40

+ 5 - 18
exo/inference/dummy_inference_engine.py

@@ -1,16 +1,9 @@
 from typing import Optional, Tuple, TYPE_CHECKING
 import numpy as np
-import random
-import string
-import asyncio
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import DummyTokenizer
 
-def random_string(length: int):
-  return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
-  
-
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
@@ -19,29 +12,23 @@ class DummyInferenceEngine(InferenceEngine):
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
+    self.num_generate_dummy_tokens = 10
     self.tokenizer = DummyTokenizer()
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     return np.array(self.tokenizer.encode(prompt))
   
   async def sample(self, x: np.ndarray) -> np.ndarray:
-    if random.random() < 0.1:
-      return np.array([self.tokenizer.eos_token_id])
-    return np.array([np.random.randint(1, self.vocab_size)])
+    if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
+    return x
 
   async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     return self.tokenizer.decode(tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
-    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
-    return output
+    return input_data + 1 if self.shard.is_last_layer() else input_data
 
   async def ensure_shard(self, shard: Shard):
-    if self.shard == shard:
-      return
-    # Simulate shard loading without making any API calls
-    await asyncio.sleep(0.1)  # Simulate a short delay
+    if self.shard == shard: return
     self.shard = shard
-    print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")
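The rework makes the dummy engine fully deterministic: tokens start at 1, each pass through the last layer adds 1, and sample() emits EOS once the value exceeds num_generate_dummy_tokens. A minimal sketch of the resulting token loop (not part of the commit; the Shard constructor arguments are an assumption about exo's dataclass):

import asyncio
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard

async def demo():
  engine = DummyInferenceEngine()
  # Shard fields here are assumed; adjust to the actual dataclass.
  shard = Shard(model_id="dummy", start_layer=0, end_layer=0, n_layers=1)
  x = await engine.encode(shard, "hello")             # always np.array([1])
  while True:
    x = await engine.infer_tensor("req-1", shard, x)  # +1 per pass on the last layer
    tok = await engine.sample(x)                      # EOS once x[0] > num_generate_dummy_tokens
    if tok[0] == engine.tokenizer.eos_token_id:
      break

asyncio.run(demo())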

+ 1 - 1
exo/inference/tokenizers.py

@@ -18,7 +18,7 @@ class DummyTokenizer:
     return "dummy_tokenized_prompt"
 
   def encode(self, text):
-    return np.random.randint(1, self.vocab_size, size=(1, len(text.split())))
+    return np.array([1])
 
   def decode(self, tokens):
     return "dummy" * len(tokens)

+ 233 - 80
extra/dashboard/dashboard.py

@@ -9,6 +9,9 @@ from typing import List, Dict, Optional
 from pathlib import Path
 from plotly.subplots import make_subplots
 import plotly.graph_objects as go
+import time
+import simpleaudio as sa
+from datetime import datetime
 
 class AsyncCircleCIClient:
     def __init__(self, token: str, project_slug: str):
@@ -104,6 +107,20 @@ class PackageSizeTracker:
         self.setup_logging(debug)
         self.client = AsyncCircleCIClient(token, project_slug)
         self.logger = logging.getLogger("PackageSizeTracker")
+        self.last_data_hash = None
+        self.notification_sound_path = Path(__file__).parent / "notification.wav"
+        self.debug = debug
+
+        # Sound file paths - replace these with your actual sound files
+        sounds_dir = Path(__file__).parent / "sounds"
+        self.sounds = {
+            'lines_up': sounds_dir / "lines_increased.wav",
+            'lines_down': sounds_dir / "lines_decreased.wav",
+            'tokens_up': sounds_dir / "tokens_increased.wav",
+            'tokens_down': sounds_dir / "tokens_decreased.wav",
+            'size_up': sounds_dir / "size_increased.wav",
+            'size_down': sounds_dir / "size_decreased.wav"
+        }
 
     def setup_logging(self, debug: bool):
         level = logging.DEBUG if debug else logging.INFO
@@ -115,20 +132,35 @@ class PackageSizeTracker:
 
     def extract_commit_info(self, pipeline: Dict) -> Optional[Dict]:
         try:
-            if 'trigger_parameters' in pipeline:
-                github_app = pipeline['trigger_parameters'].get('github_app', {})
-                if github_app:
-                    return {
-                        'commit_hash': github_app.get('checkout_sha'),
-                        'web_url': f"{github_app.get('repo_url')}/commit/{github_app.get('checkout_sha')}"
-                    }
-
-                git_params = pipeline['trigger_parameters'].get('git', {})
-                if git_params:
-                    return {
-                        'commit_hash': git_params.get('checkout_sha'),
-                        'web_url': f"{git_params.get('repo_url')}/commit/{git_params.get('checkout_sha')}"
-                    }
+            # Extract from github_app first (preferred)
+            if 'trigger_parameters' in pipeline and 'github_app' in pipeline['trigger_parameters']:
+                github_app = pipeline['trigger_parameters']['github_app']
+                return {
+                    'commit_hash': github_app.get('checkout_sha'),
+                    'web_url': f"{github_app.get('repo_url')}/commit/{github_app.get('checkout_sha')}",
+                    'branch': github_app.get('branch', 'unknown'),
+                    'author': {
+                        'name': github_app.get('commit_author_name'),
+                        'email': github_app.get('commit_author_email'),
+                        'username': github_app.get('user_username')
+                    },
+                    'message': github_app.get('commit_message')
+                }
+
+            # Fallback to git parameters
+            if 'trigger_parameters' in pipeline and 'git' in pipeline['trigger_parameters']:
+                git = pipeline['trigger_parameters']['git']
+                return {
+                    'commit_hash': git.get('checkout_sha'),
+                    'web_url': f"{git.get('repo_url')}/commit/{git.get('checkout_sha')}",
+                    'branch': git.get('branch', 'unknown'),
+                    'author': {
+                        'name': git.get('commit_author_name'),
+                        'email': git.get('commit_author_email'),
+                        'username': git.get('author_login')
+                    },
+                    'message': git.get('commit_message')
+                }
 
             self.logger.warning(f"Could not find commit info in pipeline {pipeline['id']}")
             return None
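For context, a sketch of the CircleCI pipeline payload this parser expects. The field names come from the code above; the surrounding structure is an assumption, not CircleCI documentation:

pipeline = {
  "id": "...",
  "trigger_parameters": {
    "github_app": {
      "checkout_sha": "...",
      "repo_url": "https://github.com/owner/repo",
      "branch": "main",
      "commit_author_name": "...",
      "commit_author_email": "...",
      "user_username": "...",
      "commit_message": "...",
    },
    # Fallback: a "git" key with the same fields, except "author_login"
    # replaces "user_username".
  },
}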
@@ -143,13 +175,17 @@ class PackageSizeTracker:
             if not commit_info:
                 return None
 
-            jobs = await self.client.get_workflow_jobs(session, pipeline["id"])
+            data_point = {
+                "commit_hash": commit_info['commit_hash'],
+                "commit_url": commit_info['web_url'],
+                "timestamp": pipeline.get("created_at", pipeline.get("updated_at")),
+                "pipeline_status": pipeline.get("state", "unknown"),
+                "branch": commit_info['branch'],
+                "author": commit_info['author'],
+                "commit_message": commit_info['message']
+            }
 
-            # Add test status check
-            test_job = next(
-                (j for j in jobs if j["name"] == "test" and j["status"] in ["success", "failed"]),
-                None
-            )
+            jobs = await self.client.get_workflow_jobs(session, pipeline["id"])
 
             # Get package size data
             size_job = next(
@@ -174,13 +210,6 @@ class PackageSizeTracker:
                 self.logger.debug(f"No relevant jobs found for pipeline {pipeline['id']}")
                 return None
 
-            data_point = {
-                "commit_hash": commit_info['commit_hash'],
-                "commit_url": commit_info['web_url'],
-                "timestamp": pipeline.get("created_at", pipeline.get("updated_at")),
-                "tests_passing": test_job["status"] == "success" if test_job else None
-            }
-
             # Process benchmark data if available
             if benchmark_job:
                 benchmark_artifacts = await self.client.get_artifacts(session, benchmark_job["job_number"])
@@ -283,19 +312,34 @@ class PackageSizeTracker:
             self.logger.error("No data to generate report from!")
             return None
 
+        # Get latest pipeline status based on errors
+        latest_main_pipeline = next((d for d in data if d.get('branch') == 'main'), None)
+        latest_pipeline_status = 'success' if latest_main_pipeline and not latest_main_pipeline.get('errors') else 'failure'
+
+        # Log the pipeline status
+        if latest_main_pipeline:
+            self.logger.info(
+                f"Latest main branch pipeline status: {latest_pipeline_status} "
+                f"(commit: {latest_main_pipeline['commit_hash'][:7]})"
+            )
+        else:
+            self.logger.warning("No pipeline data found for main branch")
+
+        # Convert output_dir to Path object
+        output_dir = Path(output_dir)
+
+        # Create output directory if it doesn't exist
+        output_dir.mkdir(parents=True, exist_ok=True)
+
         # Create separate dataframes for each metric
         df_size = pd.DataFrame([d for d in data if 'total_size_mb' in d])
         df_lines = pd.DataFrame([d for d in data if 'total_lines' in d])
         df_benchmark = pd.DataFrame([d for d in data if 'tokens_per_second' in d])
 
-        # Ensure output directory exists
-        output_dir = Path(output_dir)
-        output_dir.mkdir(parents=True, exist_ok=True)
-
         # Create a single figure with subplots
         fig = make_subplots(
             rows=3, cols=2,
-            subplot_titles=('Test Status', 'Package Size', '', 'Line Count', '', 'Tokens per Second'),
+            subplot_titles=('', 'Package Size', '', 'Line Count', '', 'Tokens per Second'),
             vertical_spacing=0.2,
             column_widths=[0.2, 0.8],
             specs=[[{"type": "indicator"}, {"type": "scatter"}],
@@ -303,27 +347,6 @@ class PackageSizeTracker:
                    [None, {"type": "scatter"}]]
         )
 
-        # Add test status indicator if we have data
-        latest_test_status = next((d["tests_passing"] for d in reversed(data) if "tests_passing" in d), None)
-        if latest_test_status is not None:
-            fig.add_trace(
-                go.Indicator(
-                    mode="gauge",
-                    gauge={
-                        "shape": "bullet",
-                        "axis": {"visible": False},
-                        "bar": {"color": "green" if latest_test_status else "red"},
-                        "bgcolor": "white",
-                        "steps": [
-                            {"range": [0, 1], "color": "lightgray"}
-                        ]
-                    },
-                    value=1,
-                    title={"text": "Tests<br>Status"}
-                ),
-                row=1, col=1
-            )
-
         # Add package size trace if we have data
         if not df_size.empty:
             df_size['timestamp'] = pd.to_datetime(df_size['timestamp'])
@@ -510,9 +533,36 @@ class PackageSizeTracker:
                     height: 350px;
                     display: flex;
                     flex-direction: column;
+                    align-items: center;
                     justify-content: center;
                 }}
 
+                .traffic-light {{
+                    width: 150px;
+                    height: 150px;
+                    border-radius: 50%;
+                    margin: 20px;
+                    box-shadow: 0 0 20px rgba(0,0,0,0.2);
+                    position: relative;
+                }}
+
+                .traffic-light.success {{
+                    background: #2ecc71;  /* Bright green */
+                    border: 8px solid #27ae60;  /* Darker green border */
+                }}
+
+                .traffic-light.failure {{
+                    background: #e74c3c;  /* Bright red */
+                    border: 8px solid #c0392b;  /* Darker red border */
+                }}
+
+                .status-text {{
+                    font-size: 24px;
+                    font-weight: bold;
+                    margin-top: 20px;
+                    color: #2c3e50;
+                }}
+
                 /* Override Plotly's default margins */
                 .js-plotly-plot .plotly {{
                     margin: 0 !important;
@@ -534,8 +584,11 @@ class PackageSizeTracker:
 
             <div class="dashboard-grid">
                 <div class="status-container">
-                    <div class="chart-title">Test Status</div>
-                    <div id="status-chart"></div>
+                    <div class="chart-title">Pipeline Status</div>
+                    <div class="traffic-light {'success' if latest_pipeline_status == 'success' else 'failure'}"></div>
+                    <div class="status-text">
+                        {'✓ Pipeline Passing' if latest_pipeline_status == 'success' else '✗ Pipeline Failing'}
+                    </div>
                 </div>
                 <div class="chart-row">
                     <div class="chart-box">
@@ -567,18 +620,6 @@ class PackageSizeTracker:
                 const originalData = {fig.to_json()};
 
                 function initializeCharts() {{
-                    // Create the status indicator
-                    if (originalData.data[0].type === 'indicator') {{
-                        Plotly.newPlot('status-chart',
-                            [originalData.data[0]],
-                            {{
-                                ...originalData.layout,
-                                margin: {{ t: 0, b: 0, l: 0, r: 0 }},
-                                height: 280
-                            }}
-                        );
-                    }}
-
                     // Create the size trend chart
                     const sizeTrace = originalData.data.find(trace => trace.name === 'Package Size');
                     if (sizeTrace) {{
@@ -687,13 +728,13 @@ class PackageSizeTracker:
 
                     // Update the ranges
                     const sizeUpdateLayout = {{}};
-                    sizeUpdateLayout[`${{sizeXAxisName}}.range`] = [startDate, endDate];
+                    sizeUpdateLayout[`{{sizeXAxisName}}.range`] = [startDate, endDate];
 
                     const linesUpdateLayout = {{}};
-                    linesUpdateLayout[`${{linesXAxisName}}.range`] = [startDate, endDate];
+                    linesUpdateLayout[`{{linesXAxisName}}.range`] = [startDate, endDate];
 
                     const tokensUpdateLayout = {{}};
-                    tokensUpdateLayout[`${{tokensXAxisName}}.range`] = [startDate, endDate];
+                    tokensUpdateLayout[`{{tokensXAxisName}}.range`] = [startDate, endDate];
 
                     // Update both charts
                     Plotly.relayout('size-chart', sizeUpdateLayout)
@@ -882,11 +923,129 @@ class PackageSizeTracker:
 
         print("\n")
 
+    def _calculate_data_hash(self, data: List[Dict]) -> str:
+        """Calculate a hash of the data to detect changes"""
+        return hash(str(sorted([
+            (d.get('commit_hash'), d.get('timestamp'))
+            for d in data
+        ])))
+
+    def _play_sound(self, sound_key: str):
+        """Play a specific notification sound"""
+        try:
+            sound_path = self.sounds.get(sound_key)
+            if sound_path and sound_path.exists():
+                wave_obj = sa.WaveObject.from_wave_file(str(sound_path))
+                wave_obj.play()
+            else:
+                self.logger.warning(f"Sound file not found: {sound_key} at {sound_path}")
+        except Exception as e:
+            self.logger.error(f"Failed to play sound {sound_key}: {e}")
+
+    def _check_metrics_changes(self, current_data: List[Dict], previous_data: List[Dict]):
+        """Check for specific metric changes and play appropriate sounds"""
+        if not previous_data:
+            return
+
+        # Get latest data points
+        current = current_data[-1]
+        previous = previous_data[-1]
+
+        # Check line count changes
+        if 'total_lines' in current and 'total_lines' in previous:
+            diff = current['total_lines'] - previous['total_lines']
+            if diff > 0:
+                self.logger.info(f"Lines of code increased by {diff:,}")
+                self._play_sound('lines_up')
+            elif diff < 0:
+                self.logger.info(f"Lines of code decreased by {abs(diff):,}")
+                self._play_sound('lines_down')
+
+        # Check tokens per second changes
+        if 'tokens_per_second' in current and 'tokens_per_second' in previous:
+            diff = current['tokens_per_second'] - previous['tokens_per_second']
+            if diff > 0:
+                self.logger.info(f"Tokens per second increased by {diff:.2f}")
+                self._play_sound('tokens_up')
+            elif diff < 0:
+                self.logger.info(f"Tokens per second decreased by {abs(diff):.2f}")
+                self._play_sound('tokens_down')
+
+        # Check package size changes
+        if 'total_size_mb' in current and 'total_size_mb' in previous:
+            diff = current['total_size_mb'] - previous['total_size_mb']
+            if diff > 0:
+                self.logger.info(f"Package size increased by {diff:.2f}MB")
+                self._play_sound('size_up')
+            elif diff < 0:
+                self.logger.info(f"Package size decreased by {abs(diff):.2f}MB")
+                self._play_sound('size_down')
+
+    async def run_dashboard(self, update_interval: int = 30):
+        """Run the dashboard with periodic updates"""
+        try:
+            # Force convert update_interval to float and log its type
+            update_interval = float(update_interval)
+            self.logger.debug(f"Update interval type: {type(update_interval)}, value: {update_interval}")
+        except ValueError as e:
+            self.logger.error(f"Failed to convert update_interval to float: {update_interval}")
+            raise
+
+        self.logger.info(f"Starting real-time dashboard with {update_interval}s updates")
+        previous_data = None
+
+        while True:
+            try:
+                start_time = time.time()
+                self.logger.debug(f"Start time type: {type(start_time)}, value: {start_time}")
+
+                # Collect new data
+                current_data = await self.collect_data()
+                if not current_data:
+                    self.logger.warning("No data collected")
+                    await asyncio.sleep(update_interval)
+                    continue
+
+                # Generate report
+                report_path = self.generate_report(current_data)
+                if report_path:
+                    self.logger.info(
+                        f"Dashboard updated at {datetime.now().strftime('%H:%M:%S')}"
+                    )
+
+                    # Check for metric changes and play appropriate sounds
+                    self._check_metrics_changes(current_data, previous_data)
+
+                # Update previous data
+                previous_data = current_data
+
+                # Calculate sleep time with explicit type conversion and logging
+                elapsed = float(time.time() - start_time)
+                self.logger.debug(f"Elapsed time type: {type(elapsed)}, value: {elapsed}")
+                sleep_time = max(0.0, float(update_interval) - elapsed)
+                self.logger.debug(f"Sleep time type: {type(sleep_time)}, value: {sleep_time}")
+
+                await asyncio.sleep(sleep_time)
+
+            except Exception as e:
+                self.logger.error(f"Error in dashboard update loop: {e}", exc_info=True)
+                if self.debug:
+                    raise
+                await asyncio.sleep(float(update_interval))
+
 async def main():
     token = os.getenv("CIRCLECI_TOKEN")
     project_slug = os.getenv("CIRCLECI_PROJECT_SLUG")
     debug = os.getenv("DEBUG", "").lower() in ("true", "1", "yes")
 
+    try:
+        # Get update interval from environment or use default
+        update_interval = float(os.getenv("UPDATE_INTERVAL", "30"))
+        print(f"Update interval type: {type(update_interval)}, value: {update_interval}")  # Debug print
+    except ValueError as e:
+        print(f"Error converting UPDATE_INTERVAL to float: {os.getenv('UPDATE_INTERVAL')}")
+        update_interval = 30.0
+
     if not token or not project_slug:
         print("Error: Please set CIRCLECI_TOKEN and CIRCLECI_PROJECT_SLUG environment variables")
         return
@@ -894,17 +1053,11 @@ async def main():
     tracker = PackageSizeTracker(token, project_slug, debug)
 
     try:
-        data = await tracker.collect_data()
-        if not data:
-            print("No data found!")
-            return
-
-        report_path = tracker.generate_report(data)
-        if report_path:
-            print(f"\nDetailed report available at: {report_path}")
-
+        await tracker.run_dashboard(update_interval)
+    except KeyboardInterrupt:
+        print("\nDashboard stopped by user")
     except Exception as e:
-        logging.error(f"Error: {str(e)}")
+        logging.error(f"Error: {str(e)}", exc_info=True)
         if debug:
             raise
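main() now runs the update loop indefinitely instead of generating a single report. A hypothetical programmatic launch (import path assumed; the constructor signature matches the code above):

import asyncio, os
from dashboard import PackageSizeTracker  # assumed import path

tracker = PackageSizeTracker(
  token=os.environ["CIRCLECI_TOKEN"],
  project_slug=os.environ["CIRCLECI_PROJECT_SLUG"],
  debug=False,
)
asyncio.run(tracker.run_dashboard(update_interval=10))  # regenerate every 10s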
 

+ 2 - 1
extra/dashboard/requirements.txt

@@ -1,4 +1,5 @@
 plotly
 pandas
 requests
-aiohttp
+aiohttp
+simpleaudio
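
simpleaudio backs the new change notifications. The tracker looks for six .wav files in a sounds/ directory next to dashboard.py; a hypothetical pre-flight check:

from pathlib import Path

sounds_dir = Path("extra/dashboard/sounds")
expected = [
  "lines_increased.wav", "lines_decreased.wav",
  "tokens_increased.wav", "tokens_decreased.wav",
  "size_increased.wav", "size_decreased.wav",
]
missing = [f for f in expected if not (sounds_dir / f).exists()]
print("missing sounds:", missing or "none")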