
increase max line length to 200

Alex Cheema · 1 year ago · commit 1dc08fecaa
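
The change itself is a pure formatter-settings bump: black, isort, pylint, and autopep8 move from a 120-character to a 200-character line limit (see format.py and pyproject.toml at the end of the diff), and everything else is the codebase reflowed under the new limit. As a minimal sketch of the effect — assuming black's programmatic API (black.format_str / black.Mode) and borrowing one Shard entry from the chatgpt_api.py hunk below — the same source fits on one line at 200 characters but gets reflowed at 120:

import black

# One entry from shard_mappings in chatgpt_api.py, written as a single ~175-character line.
src = 'shard_mappings = {"llama-3.1-8b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32)}}\n'

# Under the old limit the call no longer fits and black splits it across lines;
# under the new limit it stays on one line, which is what the hunks below show.
old_style = black.format_str(src, mode=black.Mode(line_length=120))
new_style = black.format_str(src, mode=black.Mode(line_length=200))

print(old_style)
print(new_style)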

+ 20 - 66
exo/api/chatgpt_api.py

@@ -15,44 +15,28 @@ from exo.orchestration import Node
 shard_mappings = {
   ### llama
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
   },
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
   },
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
-    ),
-    "TinygradDynamicShardInferenceEngine": Shard(
-      model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
   },
   ### mistral
   "mistral-nemo": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
   },
   "mistral-large": {
-    "MLXDynamicShardInferenceEngine": Shard(
-      model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88
-    ),
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
   },
   ### deepseek v2
   "deepseek-coder-v2-lite": {
@@ -82,9 +66,7 @@ def resolve_tinygrad_tokenizer(model_id: str):
   elif model_id == "llama3-70b-sfr":
     return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
   else:
-    raise ValueError(
-      f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}"
-    )
+    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
 
 
 async def resolve_tokenizer(model_id: str):
@@ -190,12 +172,8 @@ class ChatGPTAPI:
       allow_headers="*",
       allow_methods="*",
     )
-    cors.add(
-      self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}
-    )
-    cors.add(
-      self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}
-    )
+    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
@@ -226,22 +204,16 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data)
-    if chat_request.model and chat_request.model.startswith(
-      "gpt-"
-    ):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3.1-8b"
     if not chat_request.model or chat_request.model not in shard_mappings:
       if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
       chat_request.model = "llama-3.1-8b"
     shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
     if not shard:
-      supported_models = [
-        model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines
-      ]
+      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
       return web.json_response(
-        {
-          "detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"
-        },
+        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
       )
     request_id = str(uuid.uuid4())
@@ -261,9 +233,7 @@ class ChatGPTAPI:
         import traceback
 
         traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500
-      )
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
 
     try:
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
@@ -284,11 +254,7 @@ class ChatGPTAPI:
           self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
           new_tokens = tokens[prev_last_tokens_len:]
           finish_reason = None
-          eos_token_id = (
-            tokenizer.special_tokens_map.get("eos_token_id")
-            if isinstance(tokenizer._tokenizer, AutoTokenizer)
-            else tokenizer.eos_token_id
-          )
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
           if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
             new_tokens = new_tokens[:-1]
             if is_finished:
@@ -315,9 +281,7 @@ class ChatGPTAPI:
           return _request_id == request_id and is_finished
 
         _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
-        if (
-          request_id in self.stream_tasks
-        ):  # in case there is still a stream task running, wait for it to complete
+        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
           if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
           try:
             await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
@@ -332,21 +296,13 @@ class ChatGPTAPI:
         )
 
         finish_reason = "length"
-        eos_token_id = (
-          tokenizer.special_tokens_map.get("eos_token_id")
-          if isinstance(tokenizer._tokenizer, AutoTokenizer)
-          else tokenizer.eos_token_id
-        )
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
         if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
         if tokens[-1] == eos_token_id:
           tokens = tokens[:-1]
           finish_reason = "stop"
 
-        return web.json_response(
-          generate_completion(
-            chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"
-          )
-        )
+        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
     except asyncio.TimeoutError:
       return web.json_response({"detail": "Response generation timed out"}, status=408)
     finally:
@@ -359,7 +315,5 @@ class ChatGPTAPI:
     site = web.TCPSite(runner, host, port)
     await site.start()
     if DEBUG >= 0:
-      print(
-        f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}"
-      )
+      print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
       print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")

+ 3 - 9
exo/inference/debug_inference_engine.py

@@ -6,18 +6,14 @@ import numpy as np
 
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(
-  inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str
-):
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   from exo.inference.tinygrad.inference import Tokenizer
   from pathlib import Path
 
   _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
 
   prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
-    "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
-  )
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -25,9 +21,7 @@ async def test_inference_engine(
     inference_state=inference_state_full,
   )
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
-    "B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt
-  )
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),

+ 2 - 6
exo/inference/inference_engine.py

@@ -7,13 +7,9 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_tensor(
-    self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
-  ) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass
 
   @abstractmethod
-  async def infer_prompt(
-    self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
-  ) -> Tuple[np.ndarray, str, bool]:
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
     pass

+ 3 - 9
exo/inference/mlx/sharded_inference_engine.py

@@ -11,18 +11,12 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
 
-  async def infer_prompt(
-    self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
-  ) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(
-      self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))
-    )
+    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-  async def infer_tensor(
-    self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
-  ) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
     return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id

+ 1 - 5
exo/inference/mlx/sharded_model.py

@@ -61,9 +61,5 @@ class StatefulShardedModel:
     return self.step(x, temp, top_p, logit_bias)
 
   def init_cache(self, request_id: str):
-    kv_heads = (
-      [self.model.n_kv_heads] * len(self.model.layers)
-      if isinstance(self.model.n_kv_heads, int)
-      else self.model.n_kv_heads
-    )
+    kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
     self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]

+ 3 - 9
exo/inference/test_inference_engine.py

@@ -6,13 +6,9 @@ import numpy as np
 
 
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(
-  inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str
-):
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
-    "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
-  )
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
@@ -20,9 +16,7 @@ async def test_inference_engine(
     inference_state=inference_state_full,
   )
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
-    "B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt
-  )
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
   resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),

+ 6 - 18
exo/inference/tinygrad/inference.py

@@ -64,9 +64,7 @@ class Tokenizer:
     ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
     self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
 
-    self.model = tiktoken.Encoding(
-      name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens
-    )
+    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)
 
   @property
   def bos_id(self):
@@ -200,9 +198,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
 
-  async def infer_prompt(
-    self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
-  ) -> (np.ndarray, str, bool):
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
@@ -211,9 +207,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
     last_tok = toks[-1]
 
-    output_data = np.array(
-      [self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()]
-    )
+    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
     if output_data.size == 1:
       start_pos += 1
 
@@ -223,15 +217,11 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
       output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
     )
 
-  async def infer_tensor(
-    self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
-  ) -> (np.ndarray, str, bool):
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
-    output_data: np.ndarray = np.array(
-      [self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()]
-    )
+    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
     if output_data.size == 1:
       start_pos += 1
 
@@ -296,9 +286,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
         # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
         # size = "70B"
       else:
-        raise ValueError(
-          f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}"
-        )
+        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
 
     model = build_transformer(model_path, shard=shard, model_size=size)
     tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))

+ 12 - 51
exo/inference/tinygrad/models/llama.py

@@ -21,9 +21,7 @@ def complex_mult(A, c, d):
 
 
 def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
-  assert (
-    freqs_cis.shape[1] == xq.shape[1] == xk.shape[1]
-  ), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
+  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
   assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
@@ -44,9 +42,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = (
-      n_kv_heads if n_kv_heads is not None else n_heads
-    )  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -67,20 +63,14 @@ class Attention:
 
     # create kv cache
     if not hasattr(self, "cache_kv"):
-      self.cache_kv = (
-        Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
-        .contiguous()
-        .realize()
-      )
+      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
         self.cache_kv.shard_((x.device), axis=None).realize()
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(
-      Tensor.stack(xk, xv)
-    ).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
     keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
     values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
@@ -198,12 +188,7 @@ class Transformer:
     jit=True,
     feed_forward=FeedForward,
   ):
-    self.layers = [
-      TransformerBlock(
-        dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward
-      )
-      for _ in range(shard.end_layer - shard.start_layer + 1)
-    ]
+    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -227,13 +212,7 @@ class Transformer:
 
     if self.shard.is_first_layer():
       h = self.tok_embeddings(h)
-    mask = (
-      Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device)
-      .triu(start_pos + 1)
-      .realize()
-      if seqlen > 1
-      else None
-    )
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
     for i, layer in enumerate(self.layers):
       h = layer(h, start_pos, freqs_cis, mask)
@@ -270,32 +249,16 @@ class Transformer:
 # *** helpers ***
 
 
-def convert_from_huggingface(
-  weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard
-):
+def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
   keymap = {
     "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{
-      f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
-      for l in range(len(model.layers))
-    },
-    **{
-      f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
-      for x in ["q", "k", "v", "o"]
-      for l in range(len(model.layers))
-    },
-    **{
-      f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
-      for l in range(len(model.layers))
-    },
-    **{
-      f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
-      for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
-      for l in range(len(model.layers))
-    },
+    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
+    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
+    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
   }
@@ -324,6 +287,4 @@ def fix_bf16(weights: Dict[Any, Tensor]):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
-  return {
-    k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()
-  }
+  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

+ 5 - 15
exo/networking/grpc/grpc_discovery.py

@@ -78,9 +78,7 @@ class GRPCDiscovery(Discovery):
     while True:
       initial_peer_count = len(self.known_peers)
       if DEBUG_DISCOVERY >= 2:
-        print(
-          f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more..."
-        )
+        print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
       if len(self.known_peers) == initial_peer_count:
         if wait_for_peers > 0:
           await asyncio.sleep(grace_period)
@@ -95,9 +93,7 @@ class GRPCDiscovery(Discovery):
     return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
 
   async def task_broadcast_presence(self):
-    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-      lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET
-    )
+    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
     sock = transport.get_extra_info("socket")
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
 
@@ -161,9 +157,7 @@ class GRPCDiscovery(Discovery):
       self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
 
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(
-      lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port)
-    )
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2:
       print("Started listen task")
 
@@ -174,16 +168,12 @@ class GRPCDiscovery(Discovery):
         peers_to_remove = [
           peer_handle.id()
           for peer_handle, connected_at, last_seen in self.known_peers.values()
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout)
-          or current_time - last_seen > self.discovery_timeout
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
         ]
         if DEBUG_DISCOVERY >= 2:
           print(
             "Peer statuses:",
-            {
-              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
-              for peer_handle, connected_at, last_seen in self.known_peers.values()
-            },
+            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()},
           )
         if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
           print(f"Cleaning up peers: {peers_to_remove}")

+ 4 - 12
exo/networking/grpc/grpc_peer_handle.py

@@ -39,9 +39,7 @@ class GRPCPeerHandle(PeerHandle):
     self.channel = None
     self.stub = None
 
-  async def send_prompt(
-    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -60,9 +58,7 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def send_tensor(
-    self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -87,9 +83,7 @@ class GRPCPeerHandle(PeerHandle):
     if response.tensor is None:
       return None, response.is_finished
     return (
-      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(
-        response.tensor.shape
-      ),
+      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
       response.is_finished,
     )
 
@@ -98,9 +92,7 @@ class GRPCPeerHandle(PeerHandle):
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(
-        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops
-      )
+      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
       topology.update_node(node_id, device_capabilities)
     for node_id, peers in response.peer_graph.items():
       for peer_id in peers.peer_ids:

+ 4 - 16
exo/networking/grpc/grpc_server.py

@@ -49,11 +49,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     result = await self.node.process_prompt(shard, prompt, request_id)
     if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
-    return (
-      node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
-      if result is not None
-      else node_service_pb2.Tensor()
-    )
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
   async def SendTensor(self, request, context):
     shard = Shard(
@@ -62,20 +58,14 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       end_layer=request.shard.end_layer,
       n_layers=request.shard.n_layers,
     )
-    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(
-      request.tensor.shape
-    )
+    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
     inference_state = request.inference_state
 
     result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
     if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
-    return (
-      node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
-      if result is not None
-      else node_service_pb2.Tensor()
-    )
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
   async def GetInferenceResult(self, request, context):
     request_id = request.request_id
@@ -84,9 +74,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     tensor_data = result[0].tobytes() if result[0] is not None else None
     return (
       node_service_pb2.InferenceResult(
-        tensor=node_service_pb2.Tensor(
-          tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)
-        ),
+        tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
         is_finished=result[1],
       )
       if result[0] is not None

+ 2 - 6
exo/networking/peer_handle.py

@@ -28,15 +28,11 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(
-    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod
-  async def send_tensor(
-    self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 2 - 6
exo/orchestration/node.py

@@ -16,15 +16,11 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(
-    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod
-  async def process_tensor(
-    self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 16 - 52
exo/orchestration/standard_node.py

@@ -37,11 +37,7 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-    self.topology_viz = (
-      TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
-      if not disable_tui
-      else None
-    )
+    self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
     self.max_generate_tokens = max_generate_tokens
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
@@ -57,9 +53,7 @@ class StandardNode(Node):
           if status_data.get("node_id") == self.current_topology.active_node_id:
             self.current_topology.active_node_id = None
       if self.topology_viz:
-        self.topology_viz.update_visualization(
-          self.current_topology, self.partitioning_strategy.partition(self.current_topology)
-        )
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
     except json.JSONDecodeError:
       pass
 
@@ -75,9 +69,7 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()
 
-  async def process_prompt(
-    self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.ndarray]:
+  async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -121,9 +113,7 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(
-    self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
-  ) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
@@ -136,15 +126,11 @@ class StandardNode(Node):
       await self.forward_to_next_shard(shard, prompt, request_id)
       return
 
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(
-      request_id, shard, prompt, inference_state=inference_state
-    )
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
     is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(
-      self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)
-    )  # TODO: this is n^2 communication complexity
+    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
 
     if result.size == 1:
       self.buffered_token_output[request_id][0].append(result.item())
@@ -155,11 +141,7 @@ class StandardNode(Node):
     if not is_finished:
       asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
 
-    return (
-      np.array(self.buffered_token_output[request_id][0])
-      if len(self.buffered_token_output[request_id][0]) > 0
-      else None
-    )
+    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
 
   async def process_tensor(
     self,
@@ -225,15 +207,11 @@ class StandardNode(Node):
 
     try:
       if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(
-        request_id, shard, tensor, inference_state=inference_state
-      )
+      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
       is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
       if is_finished:
         self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(
-        self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)
-      )  # TODO: this is n^2 communication complexity
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
 
       if result.size == 1:  # we got a new token out
         self.buffered_token_output[request_id][0].append(result.item())
@@ -241,15 +219,9 @@ class StandardNode(Node):
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
 
       if not is_finished:
-        asyncio.create_task(
-          self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state)
-        )
-
-      return (
-        np.array(self.buffered_token_output[request_id][0])
-        if len(self.buffered_token_output[request_id][0]) > 0
-        else None
-      )
+        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       import traceback
@@ -270,9 +242,7 @@ class StandardNode(Node):
     shard = self.get_current_shard(base_shard)
 
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(
-      self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id
-    )
+    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
     if current_partition_index is not None:
@@ -295,13 +265,9 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
 
       if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(
-          next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state
-        )
+        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
       else:
-        await target_peer.send_prompt(
-          next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state
-        )
+        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
 
   def get_current_shard(self, base_shard: Shard) -> Shard:
     partitions = self.partitioning_strategy.partition(self.topology)
@@ -371,9 +337,7 @@ class StandardNode(Node):
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     if self.topology_viz:
-      self.topology_viz.update_visualization(
-        self.current_topology, self.partitioning_strategy.partition(self.current_topology)
-      )
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
     return next_topology
 
   @property

+ 1 - 3
exo/orchestration/test_node.py

@@ -21,9 +21,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
     mock_peer2.id.return_value = "peer2"
     self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
 
-    self.node = StandardNode(
-      "test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery
-    )
+    self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
 
   async def asyncSetUp(self):
     await self.node.start()

+ 2 - 6
exo/topology/device_capabilities.py

@@ -38,9 +38,7 @@ class DeviceCapabilities:
     return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()}
 
 
-UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(
-  model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)
-)
+UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
 
 CHIP_FLOPS = {
   # Source: https://www.cpu-monkey.com
@@ -140,9 +138,7 @@ def mac_device_capabilities() -> DeviceCapabilities:
     memory = memory_value
 
   # Assuming static values for other attributes for demonstration
-  return DeviceCapabilities(
-    model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))
-  )
+  return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
 
 
 def linux_device_capabilities() -> DeviceCapabilities:

+ 4 - 12
exo/viz/test_topology_viz.py

@@ -11,27 +11,19 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     self.topology = Topology()
     self.topology.update_node(
       "node1",
-      DeviceCapabilities(
-        model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)
-      ),
+      DeviceCapabilities(model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
     )
     self.topology.update_node(
       "node2",
-      DeviceCapabilities(
-        model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)
-      ),
+      DeviceCapabilities(model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
     )
     self.topology.update_node(
       "node3",
-      DeviceCapabilities(
-        model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)
-      ),
+      DeviceCapabilities(model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
     )
     self.topology.update_node(
       "node4",
-      DeviceCapabilities(
-        model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)
-      ),
+      DeviceCapabilities(model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
     )
 
     self.top_viz = TopologyViz()

+ 2 - 7
exo/viz/topology_viz.py

@@ -72,10 +72,7 @@ class TopologyViz:
           visualization[info_start_y + i][start_x + j] = char
 
     # Calculate total FLOPS and position on the bar
-    total_flops = sum(
-      self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16
-      for partition in self.partitions
-    )
+    total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
     bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
 
     # Add GPU poor/rich bar
@@ -104,9 +101,7 @@ class TopologyViz:
     pos_x = bar_start_x + int(bar_pos * bar_width)
     flops_str = f"{total_flops:.2f} TFLOPS"
     visualization[bar_y - 1][pos_x] = "▼"
-    visualization[bar_y + 1][
-      pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2
-    ] = flops_str
+    visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 2][pos_x] = "▲"
 
     for i, partition in enumerate(self.partitions):

+ 1 - 1
format.py

@@ -59,7 +59,7 @@ def run_black(target):
     exclude_patterns = '|'.join(f'({pattern.replace("*", ".*")})' for pattern in IGNORE_PATTERNS)
     command = [
         "black",
-        "--line-length", "120",
+        "--line-length", "200",
         "--extend-exclude", exclude_patterns,
         target
     ]
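
For reference, a rough sketch of the black invocation run_black builds after this change. The IGNORE_PATTERNS values and target below are hypothetical stand-ins (the real values live in format.py and are not shown in this diff); only the --line-length bump from 120 to 200 comes from the hunk above:

import subprocess

IGNORE_PATTERNS = ["*_pb2.py", "*_pb2_grpc.py"]  # hypothetical; the real list is defined in format.py
target = "exo/"  # hypothetical target directory

exclude_patterns = '|'.join(f'({pattern.replace("*", ".*")})' for pattern in IGNORE_PATTERNS)
command = [
    "black",
    "--line-length", "200",   # was "120" before this commit
    "--extend-exclude", exclude_patterns,
    target,
]
subprocess.run(command, check=False)  # format.py presumably runs something equivalent to this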

+ 4 - 4
pyproject.toml

@@ -1,17 +1,17 @@
 [tool.black]
-line-length = 120
+line-length = 200
 indent-size = 2
 skip-string-normalization = true
 
 [tool.isort]
 profile = "black"
-line_length = 120
+line_length = 200
 indent = "  "
 
 [tool.pylint.format]
 indent-string = '  '
-max-line-length = 120
+max-line-length = 200
 
 [tool.autopep8]
-max_line_length = 120
+max_line_length = 200
 indent_size = 2