瀏覽代碼

fix ruff lint errors

Alex Cheema 9 月之前
父節點
當前提交
57b2f2a4e2

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 1 - 1
exo/api/__init__.py

@@ -1 +1 @@
-from exo.api.chatgpt_api import ChatGPTAPI
+from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI

+ 5 - 7
exo/api/chatgpt_api.py

@@ -85,20 +85,18 @@ async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
-  except:
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
     import traceback
-
     if DEBUG >= 2: print(traceback.format_exc())
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
 
   try:
     if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
     return resolve_tinygrad_tokenizer(model_id)
-  except:
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
     import traceback
-
     if DEBUG >= 2: print(traceback.format_exc())
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
 
   if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
@@ -312,7 +310,7 @@ class ChatGPTAPI:
         if (
           request_id in self.stream_tasks
         ):  # in case there is still a stream task running, wait for it to complete
-          if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
+          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
           try:
             await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
           except asyncio.TimeoutError:

+ 2 - 3
exo/inference/debug_inference_engine.py

@@ -1,4 +1,3 @@
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
@@ -19,7 +18,7 @@ async def test_inference_engine(
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
     "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
   )
-  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=resp_full,
@@ -41,7 +40,7 @@ async def test_inference_engine(
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,

+ 5 - 5
exo/inference/mlx/models/sharded_llama.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -32,11 +32,11 @@ class NormalModelArgs(BaseModelArgs):
       self.num_key_value_heads = self.num_attention_heads
 
     if self.rope_scaling:
-      if not "factor" in self.rope_scaling:
-        raise ValueError(f"rope_scaling must contain 'factor'")
+      if "factor" not in self.rope_scaling:
+        raise ValueError("rope_scaling must contain 'factor'")
       rope_type = self.rope_scaling.get("type") or self.rope_scaling.get("rope_type")
       if rope_type is None:
-        raise ValueError(f"rope_scaling must contain either 'type' or 'rope_type'")
+        raise ValueError("rope_scaling must contain either 'type' or 'rope_type'")
       if rope_type not in ["linear", "dynamic", "llama3"]:
         raise ValueError("rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'")
 
@@ -186,7 +186,7 @@ class Attention(nn.Module):
     mask: Optional[mx.array] = None,
     cache: Optional[KVCache] = None,
   ) -> mx.array:
-    B, L, D = x.shape
+    B, L, _D = x.shape
 
     queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 

+ 1 - 1
exo/inference/mlx/sharded_model.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generator, Optional, Tuple
+from typing import Dict, Generator, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn

+ 0 - 1
exo/inference/mlx/test_sharded_model.py

@@ -1,5 +1,4 @@
 from exo.inference.shard import Shard
-from exo.inference.mlx.sharded_model import StatefulShardedModel
 import mlx.core as mx
 import mlx.nn as nn
 from typing import Optional

+ 2 - 3
exo/inference/test_inference_engine.py

@@ -1,7 +1,6 @@
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np
 
@@ -14,7 +13,7 @@ async def test_inference_engine(
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
     "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
   )
-  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=resp_full,
@@ -36,7 +35,7 @@ async def test_inference_engine(
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,

+ 3 - 3
exo/inference/tinygrad/inference.py

@@ -2,12 +2,12 @@ import asyncio
 from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union
-import json, argparse, random, time
+import json
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, nn, Context, GlobalCounters
 from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine

+ 2 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -13,8 +13,8 @@ from exo.topology.device_capabilities import DeviceCapabilities
 
 
 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
-    self._id = id
+  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+    self._id = _id
     self.address = address
     self._device_capabilities = device_capabilities
     self.channel = None

+ 1 - 1
exo/orchestration/node.py

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, List, Callable
+from typing import Optional, Tuple, List
 import numpy as np
 from abc import ABC, abstractmethod
 from exo.helpers import AsyncCallbackSystem

+ 3 - 3
exo/orchestration/standard_node.py

@@ -18,7 +18,7 @@ from exo.viz.topology_viz import TopologyViz
 class StandardNode(Node):
   def __init__(
     self,
-    id: str,
+    _id: str,
     server: Server,
     inference_engine: InferenceEngine,
     discovery: Discovery,
@@ -28,7 +28,7 @@ class StandardNode(Node):
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
   ):
-    self.id = id
+    self.id = _id
     self.inference_engine = inference_engine
     self.server = server
     self.discovery = discovery
@@ -358,7 +358,7 @@ class StandardNode(Node):
         continue
 
       if max_depth <= 0:
-        if DEBUG >= 2: print(f"Max depth reached. Skipping...")
+        if DEBUG >= 2: print("Max depth reached. Skipping...")
         continue
 
       try:

+ 2 - 3
exo/stats/metrics.py

@@ -1,7 +1,6 @@
 from exo.orchestration import Node
 from prometheus_client import start_http_server, Counter, Histogram
 import json
-from typing import List
 
 # Create metrics to track time spent and requests made.
 PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
@@ -14,9 +13,9 @@ def start_metrics_server(node: Node, port: int):
 
   def _on_opaque_status(request_id, opaque_status: str):
     status_data = json.loads(opaque_status)
-    type = status_data.get("type", "")
+    _type = status_data.get("type", "")
     node_id = status_data.get("node_id", "")
-    if type != "node_status":
+    if _type != "node_status":
       return
     status = status_data.get("status", "")
 

+ 3 - 3
exo/topology/device_capabilities.py

@@ -116,8 +116,8 @@ def device_capabilities() -> DeviceCapabilities:
     return linux_device_capabilities()
   else:
     return DeviceCapabilities(
-      model=f"Unknown Device",
-      chip=f"Unknown Chip",
+      model="Unknown Device",
+      chip="Unknown Chip",
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
@@ -151,7 +151,7 @@ def linux_device_capabilities() -> DeviceCapabilities:
 
   if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
   if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
-    import pynvml, pynvml_utils
+    import pynvml
 
     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(0)

+ 1 - 1
exo/topology/partitioning_strategy.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List
 from dataclasses import dataclass
 from .topology import Topology
 from exo.inference.shard import Shard

+ 0 - 1
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -1,6 +1,5 @@
 from typing import List
 from .partitioning_strategy import PartitioningStrategy
-from exo.inference.shard import Shard
 from .topology import Topology
 from .partitioning_strategy import Partition
 

+ 2 - 2
exo/topology/test_device_capabilities.py

@@ -5,7 +5,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
 
 class TestMacDeviceCapabilities(unittest.TestCase):
   @patch("subprocess.check_output")
-  def test_mac_device_capabilities(self, mock_check_output):
+  def test_mac_device_capabilities_pro(self, mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:
@@ -40,7 +40,7 @@ Activation Lock Status: Enabled
     )
 
   @patch("subprocess.check_output")
-  def test_mac_device_capabilities(self, mock_check_output):
+  def test_mac_device_capabilities_air(self, mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:

+ 1 - 2
exo/viz/topology_viz.py

@@ -8,7 +8,6 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
-from rich.color import Color
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
 
@@ -20,7 +19,7 @@ class TopologyViz:
     self.partitions: List[Partition] = []
 
     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()
 

+ 2 - 2
format.py

@@ -47,9 +47,9 @@ def adjust_indentation(content):
 def process_file(file_path, process_func):
     with open(file_path, 'r') as file:
         content = file.read()
-    
+
     modified_content = process_func(content)
-    
+
     if content != modified_content:
         with open(file_path, 'w') as file:
             file.write(modified_content)

+ 12 - 3
main.py

@@ -2,7 +2,6 @@ import argparse
 import asyncio
 import signal
 import uuid
-from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -41,11 +40,21 @@ if args.node_port is None:
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui, max_generate_tokens=args.max_generate_tokens)
+node = StandardNode(
+    args.node_id,
+    None,
+    inference_engine,
+    discovery,
+    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+    chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions",
+    web_chat_url=f"http://localhost:{args.chatgpt_api_port}",
+    disable_tui=args.disable_tui,
+    max_generate_tokens=args.max_generate_tokens,
+)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)