yapf: set blank_line_before_nested_class_or_def to False

Alex Cheema, 10 months ago
commit 14f2846a9c
38 changed files with 2 additions and 69 deletions
  1. .style.yapf (+2 -1)
  2. exo/api/chatgpt_api.py (+0 -5)
  3. exo/download/hf/hf_shard_download.py (+0 -2)
  4. exo/download/shard_download.py (+0 -1)
  5. exo/helpers.py (+0 -3)
  6. exo/inference/inference_engine.py (+0 -1)
  7. exo/inference/mlx/models/base.py (+0 -1)
  8. exo/inference/mlx/models/deepseek_v2.py (+0 -2)
  9. exo/inference/mlx/models/llama.py (+0 -2)
  10. exo/inference/mlx/models/llava.py (+0 -14)
  11. exo/inference/mlx/sharded_inference_engine.py (+0 -1)
  12. exo/inference/mlx/sharded_model.py (+0 -2)
  13. exo/inference/mlx/sharded_utils.py (+0 -1)
  14. exo/inference/mlx/test_sharded_model.py (+0 -1)
  15. exo/inference/tinygrad/inference.py (+0 -1)
  16. exo/inference/tinygrad/models/llama.py (+0 -5)
  17. exo/inference/tinygrad/tinygrad_helpers.py (+0 -1)
  18. exo/networking/discovery.py (+0 -1)
  19. exo/networking/grpc/grpc_discovery.py (+0 -2)
  20. exo/networking/grpc/grpc_peer_handle.py (+0 -1)
  21. exo/networking/grpc/grpc_server.py (+0 -1)
  22. exo/networking/grpc/node_service_pb2_grpc.py (+0 -3)
  23. exo/networking/grpc/test_grpc_discovery.py (+0 -1)
  24. exo/networking/peer_handle.py (+0 -1)
  25. exo/networking/server.py (+0 -1)
  26. exo/orchestration/node.py (+0 -1)
  27. exo/orchestration/standard_node.py (+0 -2)
  28. exo/orchestration/test_node.py (+0 -1)
  29. exo/test_callbacks.py (+0 -1)
  30. exo/topology/partitioning_strategy.py (+0 -1)
  31. exo/topology/ring_memory_weighted_partitioning_strategy.py (+0 -1)
  32. exo/topology/test_device_capabilities.py (+0 -1)
  33. exo/topology/test_map_partitions.py (+0 -1)
  34. exo/topology/test_ring_memory_weighted_partitioning_strategy.py (+0 -1)
  35. exo/topology/topology.py (+0 -1)
  36. exo/viz/test_topology_viz.py (+0 -1)
  37. exo/viz/topology_viz.py (+0 -1)
  38. extra/download_hf.py (+0 -1)

+ 2 - 1
.style.yapf

@@ -10,4 +10,5 @@ continuation_indent_width = 2
 indent_dictionary_value = True
 allow_multiline_dictionary_keys = True
 each_dict_entry_on_separate_line = False
-allow_multiline_lambdas = True
+allow_multiline_lambdas = True
+blank_line_before_nested_class_or_def = False
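
For reference, blank_line_before_nested_class_or_def controls whether yapf inserts a blank line before a def or class that is nested immediately inside another class or def. With it set to True, a class body opens with a blank line before its first method; with False, the first method sits directly under the class statement. A minimal before/after sketch, using a hypothetical Example class:

With blank_line_before_nested_class_or_def = True:

  class Example:

    def method(self):
      pass

With blank_line_before_nested_class_or_def = False:

  class Example:
    def method(self):
      pass

Every hunk below is this same one-line blank-line removal, applied across the codebase.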

+ 0 - 5
exo/api/chatgpt_api.py

@@ -18,7 +18,6 @@ from typing import Callable
 
 
 class Message:
-
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
     self.role = role
     self.content = content
@@ -28,7 +27,6 @@ class Message:
 
 
 class ChatCompletionRequest:
-
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
     self.messages = messages
@@ -148,7 +146,6 @@ def parse_chat_request(data: dict):
 
 
 class PromptSession:
-
   def __init__(self, request_id: str, timestamp: int, prompt: str):
     self.request_id = request_id
     self.timestamp = timestamp
@@ -156,7 +153,6 @@ class PromptSession:
 
 
 class ChatGPTAPI:
-
   def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
@@ -183,7 +179,6 @@ class ChatGPTAPI:
     self.app.middlewares.append(self.log_request)
 
   async def log_request(self, app, handler):
-
     async def middleware(request):
       if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
       return await handler(request)

+ 0 - 2
exo/download/hf/hf_shard_download.py

@@ -10,7 +10,6 @@ from exo.helpers import AsyncCallbackSystem, DEBUG
 
 
 class HFShardDownloader(ShardDownloader):
-
   def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
     self.quick_check = quick_check
     self.max_parallel_downloads = max_parallel_downloads
@@ -63,7 +62,6 @@ class HFShardDownloader(ShardDownloader):
         self.active_downloads.pop(shard)
 
   async def _download_shard(self, shard: Shard) -> Path:
-
     async def wrapped_progress_callback(event: RepoProgressEvent):
       self._on_progress.trigger_all(shard, event)
 

+ 0 - 1
exo/download/shard_download.py

@@ -7,7 +7,6 @@ from exo.helpers import AsyncCallbackSystem
 
 
 class ShardDownloader(ABC):
-
   @abstractmethod
   async def ensure_shard(self, shard: Shard) -> Path:
     """

+ 0 - 3
exo/helpers.py

@@ -91,7 +91,6 @@ K = TypeVar("K")
 
 
 class AsyncCallback(Generic[T]):
-
   def __init__(self) -> None:
     self.condition: asyncio.Condition = asyncio.Condition()
     self.result: Optional[Tuple[T, ...]] = None
@@ -118,7 +117,6 @@ class AsyncCallback(Generic[T]):
 
 
 class AsyncCallbackSystem(Generic[K, T]):
-
   def __init__(self) -> None:
     self.callbacks: Dict[K, AsyncCallback[T]] = {}
 
@@ -145,7 +143,6 @@ V = TypeVar('V')
 
 
 class PrefixDict(Generic[K, V]):
-
   def __init__(self):
     self.items: Dict[K, V] = {}
 

+ 0 - 1
exo/inference/inference_engine.py

@@ -7,7 +7,6 @@ from .shard import Shard
 
 
 class InferenceEngine(ABC):
-
   @abstractmethod
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     pass

+ 0 - 1
exo/inference/mlx/models/base.py

@@ -5,6 +5,5 @@ from mlx_lm.models.base import KVCache
 
 
 class IdentityBlock(nn.Module):
-
   def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
     return x

+ 0 - 2
exo/inference/mlx/models/deepseek_v2.py

@@ -24,7 +24,6 @@ class ModelArgs(ModelArgs):
 
 
 class DeepseekV2Model(nn.Module):
-
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.args = config
@@ -71,7 +70,6 @@ class DeepseekV2Model(nn.Module):
 
 
 class Model(nn.Module):
-
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.args = config

+ 0 - 2
exo/inference/mlx/models/llama.py

@@ -26,7 +26,6 @@ class ModelArgs(ModelArgs):
 
 
 class LlamaModel(nn.Module):
-
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args
@@ -70,7 +69,6 @@ class LlamaModel(nn.Module):
 
 
 class Model(nn.Module):
-
   def __init__(self, args: ModelArgs):
     super().__init__()
     self.args = args

+ 0 - 14
exo/inference/mlx/models/llava.py

@@ -33,7 +33,6 @@ class VisionConfig:
 
 
 class VisionAttention(nn.Module):
-
   def __init__(
     self,
     dims: int,
@@ -86,7 +85,6 @@ class VisionAttention(nn.Module):
 
 
 class VisionMLP(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.activation_fn = nn.GELU(approx="fast")
@@ -100,7 +98,6 @@ class VisionMLP(nn.Module):
 
 
 class VisionEncoderLayer(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.embed_dim = config.hidden_size
@@ -119,14 +116,12 @@ class VisionEncoderLayer(nn.Module):
 
 
 class VisionEncoder(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
 
 
 class VisionEmbeddings(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.config = config
@@ -160,7 +155,6 @@ class VisionEmbeddings(nn.Module):
 
 
 class ClipVisionModel(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
     self.embeddings = VisionEmbeddings(config)
@@ -188,7 +182,6 @@ class ClipVisionModel(nn.Module):
 
 
 class VisionModel(nn.Module):
-
   def __init__(self, config: VisionConfig):
     super().__init__()
 
@@ -258,7 +251,6 @@ class TextConfig:
 
 
 class TextAttention(nn.Module):
-
   def __init__(self, config: TextConfig):
     super().__init__()
 
@@ -313,7 +305,6 @@ class TextAttention(nn.Module):
 
 
 class TextMLP(nn.Module):
-
   def __init__(self, dim, hidden_dim):
     super().__init__()
     self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
@@ -325,7 +316,6 @@ class TextMLP(nn.Module):
 
 
 class TransformerBlock(nn.Module):
-
   def __init__(self, config: TextConfig):
     super().__init__()
     self.num_attention_heads = config.num_attention_heads
@@ -350,7 +340,6 @@ class TransformerBlock(nn.Module):
 
 
 class Llama(nn.Module):
-
   def __init__(self, config: TextConfig, shard: Shard):
     super().__init__()
     self.config = config
@@ -404,7 +393,6 @@ class Llama(nn.Module):
 
 
 class LanguageModel(nn.Module):
-
   def __init__(self, config: TextConfig, shard: Shard):
     super().__init__()
     self.model_type = config.model_type
@@ -486,7 +474,6 @@ class ModelArgs(LlaVAConfig):
 
 
 class LlavaMultiModalProjector(nn.Module):
-
   def __init__(self, config: LlaVAConfig):
     super().__init__()
     self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
@@ -501,7 +488,6 @@ class LlavaMultiModalProjector(nn.Module):
 
 
 class Model(nn.Module):
-
   def __init__(self, config: ModelArgs):
     super().__init__()
     self.config = config

+ 0 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -9,7 +9,6 @@ from exo.download.shard_download import ShardDownloader
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader

+ 0 - 2
exo/inference/mlx/sharded_model.py

@@ -10,7 +10,6 @@ from ..shard import Shard
 
 
 class StatefulShardedModel:
-
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard
     self.model = model
@@ -27,7 +26,6 @@ class StatefulShardedModel:
     top_p: float = 1.0,
     logit_bias: Optional[Dict[int, float]] = None,
   ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-
     def sample(logits: mx.array) -> Tuple[mx.array, float]:
       if logit_bias:
         indices = mx.array(list(logit_bias.keys()))

+ 0 - 1
exo/inference/mlx/sharded_utils.py

@@ -25,7 +25,6 @@ from ..shard import Shard
 
 
 class ModelNotFoundError(Exception):
-
   def __init__(self, message):
     self.message = message
     super().__init__(self.message)

+ 0 - 1
exo/inference/mlx/test_sharded_model.py

@@ -6,7 +6,6 @@ import numpy as np
 
 
 class DummyModel(nn.Module):
-
   def __init__(self, shard: Optional[Shard] = None):
     self.shard = shard
     self.layers = [

+ 0 - 1
exo/inference/tinygrad/inference.py

@@ -48,7 +48,6 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
 
 
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader

+ 0 - 5
exo/inference/tinygrad/models/llama.py

@@ -38,7 +38,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
 
 
 class Attention:
-
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
@@ -88,7 +87,6 @@ class Attention:
 
 
 class FeedForward:
-
   def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
@@ -99,7 +97,6 @@ class FeedForward:
 
 
 class TransformerBlock:
-
   def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
@@ -165,7 +162,6 @@ from exo.inference.shard import Shard
 
 
 class Transformer:
-
   def __init__(
     self,
     dim: int,
@@ -222,7 +218,6 @@ class Transformer:
 
 
 def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
-
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 

+ 0 - 1
exo/inference/tinygrad/tinygrad_helpers.py

@@ -11,7 +11,6 @@ from fnmatch import fnmatch
 
 # **** helper functions ****
 def concat_weights(models, device=None):
-
   def convert(name) -> Tensor:
     disk_tensors: List[Tensor] = [model[name] for model in models]
     if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:

+ 0 - 1
exo/networking/discovery.py

@@ -4,7 +4,6 @@ from .peer_handle import PeerHandle
 
 
 class Discovery(ABC):
-
   @abstractmethod
   async def start(self) -> None:
     pass

+ 0 - 2
exo/networking/grpc/grpc_discovery.py

@@ -11,7 +11,6 @@ from exo import DEBUG_DISCOVERY
 
 
 class ListenProtocol(asyncio.DatagramProtocol):
-
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
     self.on_message = on_message
@@ -25,7 +24,6 @@ class ListenProtocol(asyncio.DatagramProtocol):
 
 
 class GRPCDiscovery(Discovery):
-
   def __init__(
     self,
     node_id: str,

+ 0 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -13,7 +13,6 @@ from exo.topology.device_capabilities import DeviceCapabilities
 
 
 class GRPCPeerHandle(PeerHandle):
-
   def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
     self._id = _id
     self.address = address

+ 0 - 1
exo/networking/grpc/grpc_server.py

@@ -11,7 +11,6 @@ from exo.orchestration import Node
 
 
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
-
   def __init__(self, node: Node, host: str, port: int):
     self.node = node
     self.host = host

+ 0 - 3
exo/networking/grpc/node_service_pb2_grpc.py

@@ -27,7 +27,6 @@ if _version_not_supported:
 
 class NodeServiceStub(object):
   """Missing associated documentation comment in .proto file."""
-
   def __init__(self, channel):
     """Constructor.
 
@@ -74,7 +73,6 @@ class NodeServiceStub(object):
 
 class NodeServiceServicer(object):
   """Missing associated documentation comment in .proto file."""
-
   def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
     context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -159,7 +157,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
 # This class is part of an EXPERIMENTAL API.
 class NodeService(object):
   """Missing associated documentation comment in .proto file."""
-
   @staticmethod
   def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
     return grpc.experimental.unary_unary(

+ 0 - 1
exo/networking/grpc/test_grpc_discovery.py

@@ -4,7 +4,6 @@ from .grpc_discovery import GRPCDiscovery
 
 
 class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
-
   async def asyncSetUp(self):
     self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
     self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)

+ 0 - 1
exo/networking/peer_handle.py

@@ -7,7 +7,6 @@ from exo.topology.topology import Topology
 
 
 class PeerHandle(ABC):
-
   @abstractmethod
   def id(self) -> str:
     pass

+ 0 - 1
exo/networking/server.py

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 
 
 class Server(ABC):
-
   @abstractmethod
   async def start(self) -> None:
     pass

+ 0 - 1
exo/orchestration/node.py

@@ -7,7 +7,6 @@ from exo.topology.topology import Topology
 
 
 class Node(ABC):
-
   @abstractmethod
   async def start(self, wait_for_peers: int = 0) -> None:
     pass

+ 0 - 2
exo/orchestration/standard_node.py

@@ -18,7 +18,6 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 
 
 class StandardNode(Node):
-
   def __init__(
     self,
     _id: str,
@@ -360,7 +359,6 @@ class StandardNode(Node):
     self.on_token.trigger_all(request_id, tokens, is_finished)
 
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-
     async def send_result_to_peer(peer):
       try:
         await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)

+ 0 - 1
exo/orchestration/test_node.py

@@ -7,7 +7,6 @@ from exo.networking.peer_handle import PeerHandle
 
 
 class TestNode(unittest.IsolatedAsyncioTestCase):
-
   def setUp(self):
     self.mock_inference_engine = AsyncMock()
     self.mock_server = AsyncMock()

+ 0 - 1
exo/test_callbacks.py

@@ -12,7 +12,6 @@ async def main() -> None:
   callback2 = callback_system.register("callback2")
 
   def on_next_callback(name: str) -> Callable[..., None]:
-
     def callback(*args: Any) -> None:
       print(f"{name} received values: {args}")
 

+ 0 - 1
exo/topology/partitioning_strategy.py

@@ -14,7 +14,6 @@ class Partition:
 
 
 class PartitioningStrategy(ABC):
-
   @abstractmethod
   def partition(self, topology: Topology) -> List[Partition]:
     pass

+ 0 - 1
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -5,7 +5,6 @@ from .partitioning_strategy import Partition
 
 
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
-
   def partition(self, topology: Topology) -> List[Partition]:
     nodes = list(topology.all_nodes())
     nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)

+ 0 - 1
exo/topology/test_device_capabilities.py

@@ -4,7 +4,6 @@ from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapa
 
 
 class TestMacDeviceCapabilities(unittest.TestCase):
-
   @patch("subprocess.check_output")
   def test_mac_device_capabilities_pro(self, mock_check_output):
     # Mock the subprocess output

+ 0 - 1
exo/topology/test_map_partitions.py

@@ -5,7 +5,6 @@ from exo.inference.shard import Shard
 
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
-
   def test_map_partitions_to_shards(self):
     partitions = [
       Partition("node1", 0.0, 0.42857),

+ 0 - 1
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -6,7 +6,6 @@ from exo.topology.partitioning_strategy import Partition
 
 
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
-
   def test_partition(self):
     # triangle
     # node1 -> node2 -> node3 -> node1

+ 0 - 1
exo/topology/topology.py

@@ -3,7 +3,6 @@ from typing import Dict, Set, Optional
 
 
 class Topology:
-
   def __init__(self):
     self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
     self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph

+ 0 - 1
exo/viz/test_topology_viz.py

@@ -62,7 +62,6 @@ def create_hf_repo_progress_event(
 
 
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
-
   async def asyncSetUp(self):
     self.topology = Topology()
     self.topology.update_node(

+ 0 - 1
exo/viz/topology_viz.py

@@ -18,7 +18,6 @@ from rich.markdown import Markdown
 
 
 class TopologyViz:
-
   def __init__(self, chatgpt_api_endpoints: List[str] = [], web_chat_urls: List[str] = []):
     self.chatgpt_api_endpoints = chatgpt_api_endpoints
     self.web_chat_urls = web_chat_urls

+ 0 - 1
extra/download_hf.py

@@ -24,7 +24,6 @@ DEFAULT_IGNORE_PATTERNS = [
 
 
 async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
-
   async def progress_callback(event: RepoProgressEvent):
     print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
     print(f"Estimated time remaining: {event.overall_eta}")