fix ruff lint errors

Alex Cheema, 9 months ago
parent
commit
57b2f2a4e2

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 1 - 1
exo/api/__init__.py

@@ -1 +1 @@
-from exo.api.chatgpt_api import ChatGPTAPI
+from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI
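
Note on the two __init__.py changes above: ruff's unused-import rule (F401) can flag names that a module imports but never uses, which is exactly what re-exports in an __init__.py look like. The redundant-alias form "import X as X" is the conventional way to mark a name as an intentional re-export so that ruff (and type checkers such as mypy) keep it. A minimal sketch under that assumption, using a hypothetical package mypkg:

    # mypkg/__init__.py
    # A plain "from mypkg.helpers import DEBUG" could be reported as unused here;
    # the redundant alias declares DEBUG part of the package's public surface.
    from mypkg.helpers import DEBUG as DEBUG

    # Consumers can then do: from mypkg import DEBUG

Listing the names in __all__ is an alternative way to declare the same intent.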

+ 5 - 7
exo/api/chatgpt_api.py

@@ -85,20 +85,18 @@ async def resolve_tokenizer(model_id: str):
   try:
     if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
     return AutoTokenizer.from_pretrained(model_id)
-  except:
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
     import traceback
-
     if DEBUG >= 2: print(traceback.format_exc())
-    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")

   try:
     if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
     return resolve_tinygrad_tokenizer(model_id)
-  except:
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
     import traceback
-
     if DEBUG >= 2: print(traceback.format_exc())
-    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")

   if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
   from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
@@ -312,7 +310,7 @@ class ChatGPTAPI:
         if (
           request_id in self.stream_tasks
         ):  # in case there is still a stream task running, wait for it to complete
-          if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
+          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
           try:
             await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
           except asyncio.TimeoutError:
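
Note on the exception handling above: a bare "except:" clause is what ruff's E722 rule targets, because it also catches KeyboardInterrupt and SystemExit; switching to "except Exception as e" keeps the fallback behaviour while binding the error so it can be logged. A small self-contained sketch of the pattern (load_primary and load_fallback are hypothetical stand-ins for the tokenizer loaders, not code from this repo):

    def load_primary(model_id: str) -> str:
        raise RuntimeError("primary loader unavailable")  # simulate a failure

    def load_fallback(model_id: str) -> str:
        return f"fallback tokenizer for {model_id}"

    def load_with_fallback(model_id: str) -> str:
        try:
            return load_primary(model_id)
        except Exception as e:  # E722 would flag a bare "except:" here
            print(f"Primary load failed for {model_id}: {e}")
            return load_fallback(model_id)

    print(load_with_fallback("example-model"))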

+ 2 - 3
exo/inference/debug_inference_engine.py

@@ -1,4 +1,3 @@
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
@@ -19,7 +18,7 @@ async def test_inference_engine(
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
     "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
   )
-  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=resp_full,
@@ -41,7 +40,7 @@ async def test_inference_engine(
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
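
Note on the renamed tuple elements above (and the matching changes in exo/inference/test_inference_engine.py further down): ruff's unused-variable warnings treat leading-underscore names as intentionally unused via its dummy-variable pattern, so prefixing the discarded parts of an unpacked return value silences the warning without changing behaviour. A minimal sketch, with a hypothetical infer() returning a three-tuple:

    def infer() -> tuple[str, dict, bool]:
        return "tokens", {"cache": None}, True

    # Only the first element is needed; the underscore prefix marks the rest as unused.
    resp, _state, _done = infer()
    print(resp)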

+ 5 - 5
exo/inference/mlx/models/sharded_llama.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -32,11 +32,11 @@ class NormalModelArgs(BaseModelArgs):
       self.num_key_value_heads = self.num_attention_heads

     if self.rope_scaling:
-      if not "factor" in self.rope_scaling:
-        raise ValueError(f"rope_scaling must contain 'factor'")
+      if "factor" not in self.rope_scaling:
+        raise ValueError("rope_scaling must contain 'factor'")
       rope_type = self.rope_scaling.get("type") or self.rope_scaling.get("rope_type")
       if rope_type is None:
-        raise ValueError(f"rope_scaling must contain either 'type' or 'rope_type'")
+        raise ValueError("rope_scaling must contain either 'type' or 'rope_type'")
       if rope_type not in ["linear", "dynamic", "llama3"]:
         raise ValueError("rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'")

@@ -186,7 +186,7 @@ class Attention(nn.Module):
     mask: Optional[mx.array] = None,
     cache: Optional[KVCache] = None,
   ) -> mx.array:
-    B, L, D = x.shape
+    B, L, _D = x.shape

     queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)


+ 1 - 1
exo/inference/mlx/sharded_model.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generator, Optional, Tuple
+from typing import Dict, Generator, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

+ 0 - 1
exo/inference/mlx/test_sharded_model.py

@@ -1,5 +1,4 @@
 from exo.inference.shard import Shard
-from exo.inference.mlx.sharded_model import StatefulShardedModel
 import mlx.core as mx
 import mlx.nn as nn
 from typing import Optional

+ 2 - 3
exo/inference/test_inference_engine.py

@@ -1,7 +1,6 @@
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np

@@ -14,7 +13,7 @@ async def test_inference_engine(
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
     "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
   )
-  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     input_data=resp_full,
@@ -36,7 +35,7 @@ async def test_inference_engine(
     input_data=resp2,
     inference_state=inference_state_2,
   )
-  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,

+ 3 - 3
exo/inference/tinygrad/inference.py

@@ -2,12 +2,12 @@ import asyncio
 from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union
-import json, argparse, random, time
+import json
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, nn, Context, GlobalCounters
 from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine

+ 2 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -13,8 +13,8 @@ from exo.topology.device_capabilities import DeviceCapabilities


 class GRPCPeerHandle(PeerHandle):
-  def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
-    self._id = id
+  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+    self._id = _id
     self.address = address
     self._device_capabilities = device_capabilities
     self.channel = None
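
Note on the id -> _id rename above (the same pattern appears in exo/orchestration/standard_node.py below, and type -> _type in exo/stats/metrics.py): id and type are Python builtins, and shadowing them is the kind of thing ruff can flag via its flake8-builtins rules (A002 for arguments, A001 for variables). A short sketch of why the shadowing matters; describe() is a hypothetical example, not code from this repo:

    def describe(_id: str) -> str:
        # Naming the parameter "id" would hide the builtin id() inside this function.
        return f"object {_id}, builtin id still usable: {id(object())}"

    print(describe("node-1"))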

+ 1 - 1
exo/orchestration/node.py

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, List, Callable
+from typing import Optional, Tuple, List
 import numpy as np
 from abc import ABC, abstractmethod
 from exo.helpers import AsyncCallbackSystem

+ 3 - 3
exo/orchestration/standard_node.py

@@ -18,7 +18,7 @@ from exo.viz.topology_viz import TopologyViz
 class StandardNode(Node):
   def __init__(
     self,
-    id: str,
+    _id: str,
     server: Server,
     inference_engine: InferenceEngine,
     discovery: Discovery,
@@ -28,7 +28,7 @@ class StandardNode(Node):
     web_chat_url: Optional[str] = None,
     disable_tui: Optional[bool] = False,
   ):
-    self.id = id
+    self.id = _id
     self.inference_engine = inference_engine
     self.server = server
     self.discovery = discovery
@@ -358,7 +358,7 @@ class StandardNode(Node):
         continue

       if max_depth <= 0:
-        if DEBUG >= 2: print(f"Max depth reached. Skipping...")
+        if DEBUG >= 2: print("Max depth reached. Skipping...")
         continue

       try:

+ 2 - 3
exo/stats/metrics.py

@@ -1,7 +1,6 @@
 from exo.orchestration import Node
 from prometheus_client import start_http_server, Counter, Histogram
 import json
-from typing import List

 # Create metrics to track time spent and requests made.
 PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
@@ -14,9 +13,9 @@ def start_metrics_server(node: Node, port: int):

   def _on_opaque_status(request_id, opaque_status: str):
     status_data = json.loads(opaque_status)
-    type = status_data.get("type", "")
+    _type = status_data.get("type", "")
     node_id = status_data.get("node_id", "")
-    if type != "node_status":
+    if _type != "node_status":
       return
     status = status_data.get("status", "")


+ 3 - 3
exo/topology/device_capabilities.py

@@ -116,8 +116,8 @@ def device_capabilities() -> DeviceCapabilities:
     return linux_device_capabilities()
   else:
     return DeviceCapabilities(
-      model=f"Unknown Device",
-      chip=f"Unknown Chip",
+      model="Unknown Device",
+      chip="Unknown Chip",
       memory=psutil.virtual_memory().total // 2**20,
       flops=DeviceFlops(fp32=0, fp16=0, int8=0),
     )
@@ -151,7 +151,7 @@ def linux_device_capabilities() -> DeviceCapabilities:

   if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
   if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
-    import pynvml, pynvml_utils
+    import pynvml

     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
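
Note on the "Unknown Device" / "Unknown Chip" strings above: an f-string with no placeholders is what ruff's F541 rule reports, and dropping the f prefix is the usual fix (the same change appears in chatgpt_api.py, standard_node.py and topology_viz.py). A trivial illustration:

    model = "Unknown Device"      # plain string: nothing to interpolate, no f needed
    label = f"Detected: {model}"  # f-string justified by the placeholder
    print(label)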

+ 1 - 1
exo/topology/partitioning_strategy.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List
 from dataclasses import dataclass
 from .topology import Topology
 from exo.inference.shard import Shard

+ 0 - 1
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -1,6 +1,5 @@
 from typing import List
 from .partitioning_strategy import PartitioningStrategy
-from exo.inference.shard import Shard
 from .topology import Topology
 from .partitioning_strategy import Partition


+ 2 - 2
exo/topology/test_device_capabilities.py

@@ -5,7 +5,7 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa

 class TestMacDeviceCapabilities(unittest.TestCase):
   @patch("subprocess.check_output")
-  def test_mac_device_capabilities(self, mock_check_output):
+  def test_mac_device_capabilities_pro(self, mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:
@@ -40,7 +40,7 @@ Activation Lock Status: Enabled
     )

   @patch("subprocess.check_output")
-  def test_mac_device_capabilities(self, mock_check_output):
+  def test_mac_device_capabilities_air(self, mock_check_output):
     # Mock the subprocess output
     mock_check_output.return_value = b"""
 Hardware:
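
Note on the test renames above: both test methods previously shared the name test_mac_device_capabilities, so the second definition silently replaced the first and only one of them ever ran; ruff's F811 rule reports this kind of redefinition. Giving them distinct _pro/_air suffixes makes both run. A minimal sketch of the pitfall, using a hypothetical test case:

    import unittest

    class ExampleTest(unittest.TestCase):
        def test_value(self):       # F811: overwritten by the definition below
            self.assertEqual(1, 1)

        def test_value(self):       # only this method is collected and executed
            self.assertEqual(2, 2)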

+ 1 - 2
exo/viz/topology_viz.py

@@ -8,7 +8,6 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
-from rich.color import Color
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES


@@ -20,7 +19,7 @@ class TopologyViz:
     self.partitions: List[Partition] = []

     self.console = Console()
-    self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
     self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
     self.live_panel.start()


+ 2 - 2
format.py

@@ -47,9 +47,9 @@ def adjust_indentation(content):
 def process_file(file_path, process_func):
     with open(file_path, 'r') as file:
         content = file.read()
-    
+
     modified_content = process_func(content)
-    
+
     if content != modified_content:
         with open(file_path, 'w') as file:
             file.write(modified_content)

+ 12 - 3
main.py

@@ -2,7 +2,6 @@ import argparse
 import asyncio
 import signal
 import uuid
-from typing import List
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -41,11 +40,21 @@ if args.node_port is None:
     if DEBUG >= 1: print(f"Using available port: {args.node_port}")

 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui, max_generate_tokens=args.max_generate_tokens)
+node = StandardNode(
+    args.node_id,
+    None,
+    inference_engine,
+    discovery,
+    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+    chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions",
+    web_chat_url=f"http://localhost:{args.chatgpt_api_port}",
+    disable_tui=args.disable_tui,
+    max_generate_tokens=args.max_generate_tokens,
+)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
 if args.prometheus_client_port:
     from exo.stats.metrics import start_metrics_server
     start_metrics_server(node, args.prometheus_client_port)