
more compact operator formatting

Alex Cheema 10 months ago
commit f53056dede

+ 6 - 1
.style.yapf

@@ -11,4 +11,9 @@ indent_dictionary_value = True
 allow_multiline_dictionary_keys = True
 each_dict_entry_on_separate_line = False
 allow_multiline_lambdas = True
-blank_line_before_nested_class_or_def = False
+blank_line_before_nested_class_or_def = False
+arithmetic_precedence_indication = True
+no_spaces_around_selected_binary_operators = "*,/"
+coalesce_brackets = True
+space_between_ending_comma_and_closing_bracket = False
+split_before_expression_after_opening_paren = False
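
These five new options are what produce the tighter diffs below; in particular, no_spaces_around_selected_binary_operators = "*,/" tells yapf to drop the spaces around * and /, while coalesce_brackets and split_before_expression_after_opening_paren = False keep bracketed expressions on fewer lines. A before/after sketch of the effect (illustrative Python, not part of the commit):

  # before (yapf defaults)
  size = 100 * 1024 * 1024
  layers = [None] * 32

  # after (no_spaces_around_selected_binary_operators = "*,/")
  size = 100*1024*1024
  layers = [None]*32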

+ 3 - 3
exo/api/chatgpt_api.py

@@ -158,7 +158,7 @@ class ChatGPTAPI:
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout_secs = response_timeout_secs
     self.on_chat_completion_request = on_chat_completion_request
-    self.app = web.Application(client_max_size=100 * 1024 * 1024)  # 100MB to support image upload
+    self.app = web.Application(client_max_size=100*1024*1024)  # 100MB to support image upload
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
@@ -171,7 +171,7 @@ class ChatGPTAPI:
     )
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
-    self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
+    self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")

@@ -186,7 +186,7 @@ class ChatGPTAPI:
     return middleware

   async def handle_root(self, request):
-    return web.FileResponse(self.static_dir / "index.html")
+    return web.FileResponse(self.static_dir/"index.html")

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
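
Note that the x / "y" → x/"y" changes here (and throughout this commit) operate on pathlib objects, where / is the overloaded Path.__truediv__ join operator, so the tighter spacing is purely cosmetic. A quick runnable check:

  from pathlib import Path

  # Path overloads "/" as a join operator; spacing never changes the result
  a = Path.home()/".cache"/"huggingface"
  b = Path.home() / ".cache" / "huggingface"
  assert a == b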

+ 21 - 21
exo/download/hf/hf_helpers.py

@@ -62,12 +62,12 @@ def _add_wildcard_to_directories(pattern: str) -> str:

 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
-  return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface"))
+  return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))


 async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
-  token_path = get_hf_home() / "token"
+  token_path = get_hf_home()/"token"
   if await aios.path.exists(token_path):
     async with aiofiles.open(token_path, 'r') as f:
       return (await f.read()).strip()
@@ -85,7 +85,7 @@ async def get_auth_headers():
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   sanitized_repo_id = repo_id.replace("/", "--")
-  return get_hf_home() / "hub" / f"models--{sanitized_repo_id}"
+  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"


 async def fetch_file_list(session, repo_id, revision, path=""):
@@ -181,9 +181,9 @@ async def download_file(
         downloaded_this_session += len(chunk)
         if progress_callback and total_size:
           elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session / elapsed_time) if elapsed_time > 0 else 0
+          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
           remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size / speed) if speed > 0 else timedelta(0)
+          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
           status = "in_progress" if downloaded_size < total_size else "complete"
           if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
           await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
@@ -199,9 +199,9 @@ async def download_repo_files(
   max_parallel_downloads: int = 4
 ) -> Path:
   repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root / "refs"
-  snapshots_dir = repo_root / "snapshots"
-  cachedreqs_dir = repo_root / "cachedreqs"
+  refs_dir = repo_root/"refs"
+  snapshots_dir = repo_root/"snapshots"
+  cachedreqs_dir = repo_root/"cachedreqs"

   # Ensure directories exist
   await aios.makedirs(refs_dir, exist_ok=True)
@@ -209,7 +209,7 @@ async def download_repo_files(
   await aios.makedirs(cachedreqs_dir, exist_ok=True)

   # Check if we have a cached commit hash
-  refs_file = refs_dir / revision
+  refs_file = refs_dir/revision
   if await aios.path.exists(refs_file):
     async with aiofiles.open(refs_file, 'r') as f:
       commit_hash = (await f.read()).strip()
@@ -230,13 +230,13 @@ async def download_repo_files(
         await f.write(commit_hash)

   # Set up the snapshot directory
-  snapshot_dir = snapshots_dir / commit_hash
+  snapshot_dir = snapshots_dir/commit_hash
   await aios.makedirs(snapshot_dir, exist_ok=True)

   # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir / commit_hash
+  cached_file_list_dir = cachedreqs_dir/commit_hash
   await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir / "fetch_file_list.json"
+  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"

   async with aiohttp.ClientSession() as session:
     # Check if we have a cached file list
@@ -261,7 +261,7 @@ async def download_repo_files(
     start_time = datetime.now()

     async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir / file_info["path"]
+      local_path = snapshot_dir/file_info["path"]
       if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
         if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
         progress_state['completed_files'] += 1
@@ -269,9 +269,9 @@ async def download_repo_files(
         file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
         if progress_callback:
           elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
           remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
           status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
           await progress_callback(
             RepoProgressEvent(
@@ -287,9 +287,9 @@ async def download_repo_files(
         file_progress[event.file_path] = event
         if progress_callback:
           elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
           remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
           status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
           await progress_callback(
             RepoProgressEvent(
@@ -305,9 +305,9 @@ async def download_repo_files(
       ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
       if progress_callback:
         elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session'] / elapsed_time) if elapsed_time > 0 else 0
+        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
         remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes / overall_speed) if overall_speed > 0 else timedelta(seconds=0)
+        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
         status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
         await progress_callback(
           RepoProgressEvent(
@@ -347,11 +347,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[

   # Check if the file exists
   repo_root = get_repo_root(repo_id)
-  snapshot_dir = repo_root / "snapshots"
+  snapshot_dir = repo_root/"snapshots"
   index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)

   if index_file:
-    index_file_path = snapshot_dir / index_file
+    index_file_path = snapshot_dir/index_file
     if await aios.path.exists(index_file_path):
       async with aiofiles.open(index_file_path, 'r') as f:
         index_data = json.loads(await f.read())
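
All of the speed/ETA blocks in this file share the same zero-guarded arithmetic. A minimal sketch of that pattern as a standalone helper (hypothetical; the commit keeps the logic inline):

  from datetime import timedelta

  def speed_and_eta(downloaded_this_session: int, elapsed_s: float, remaining: int):
    # bytes/s, guarded against division by zero as in the diff above
    speed = int(downloaded_this_session/elapsed_s) if elapsed_s > 0 else 0
    eta = timedelta(seconds=remaining/speed) if speed > 0 else timedelta(0)
    return speed, eta

  # e.g. 50 MiB in 10 s with 100 MiB left -> (5242880, timedelta(seconds=20))
  print(speed_and_eta(50*1024*1024, 10.0, 100*1024*1024))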

+ 1 - 1
exo/download/hf/hf_shard_download.py

@@ -22,7 +22,7 @@ class HFShardDownloader(ShardDownloader):
       return self.completed_downloads[shard]
     if self.quick_check:
       repo_root = get_repo_root(shard.model_id)
-      snapshots_dir = repo_root / "snapshots"
+      snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
         most_recent_dir = max(snapshots_dir.iterdir(), key=lambda x: x.stat().st_mtime)
         return most_recent_dir

+ 1 - 1
exo/helpers.py

@@ -169,7 +169,7 @@ def is_valid_uuid(val):


 def get_or_create_node_id():
-  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
+  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__)))/".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:

+ 1 - 1
exo/inference/debug_inference_engine.py

@@ -10,7 +10,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   from exo.inference.tinygrad.inference import Tokenizer
   from pathlib import Path

-  _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
+  _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))

   prompt = "In a single word only, what is the last name of the president of the United States? "
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)

+ 1 - 1
exo/inference/mlx/models/deepseek_v2.py

@@ -59,7 +59,7 @@ class DeepseekV2Model(nn.Module):
       mask = mask.astype(h.dtype)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

+ 1 - 1
exo/inference/mlx/models/llama.py

@@ -58,7 +58,7 @@ class LlamaModel(nn.Module):
       mask = create_attention_mask(h, cache)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, cache=c)

+ 12 - 12
exo/inference/mlx/models/llava.py

@@ -74,8 +74,8 @@ class VisionAttention(nn.Module):
     keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
     values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

-    scale = math.sqrt(1 / queries.shape[-1])
-    scores = (queries * scale) @ keys
+    scale = math.sqrt(1/queries.shape[-1])
+    scores = (queries*scale) @ keys
     if mask is not None:
       scores = scores + mask.astype(scores.dtype)
     scores = mx.softmax(scores, axis=-1)
@@ -129,7 +129,7 @@ class VisionEmbeddings(nn.Module):
     self.image_size = config.image_size
     self.patch_size = config.patch_size

-    self.class_embedding = mx.zeros((config.hidden_size, ))
+    self.class_embedding = mx.zeros((config.hidden_size,))

     self.patch_embedding = nn.Conv2d(
       in_channels=config.num_channels,
@@ -170,12 +170,12 @@ class ClipVisionModel(nn.Module):
     x = self.embeddings(x)
     x = self.pre_layrnorm(x)

-    encoder_states = (x, ) if output_hidden_states else None
+    encoder_states = (x,) if output_hidden_states else None

     for l in self.encoder.layers:
       x = l(x, mask=None)
       if output_hidden_states:
-        encoder_states = encoder_states + (x, )
+        encoder_states = encoder_states + (x,)

     pooler_output = self.post_layernorm(x[:, 0, :])
     return pooler_output, x, encoder_states
@@ -263,12 +263,12 @@ class TextAttention(nn.Module):
     head_dim = config.hidden_size // n_heads
     self.scale = head_dim**-0.5

-    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
-    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
-    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+    self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
+    self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
+    self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)

-    rope_scale = (1 / config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
+    rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
     self.rope = nn.RoPE(
       head_dim,
       traditional=config.rope_traditional,
@@ -312,7 +312,7 @@ class TextMLP(nn.Module):
     self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

   def __call__(self, x) -> mx.array:
-    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+    return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))


 class TransformerBlock(nn.Module):
@@ -382,7 +382,7 @@ class Llama(nn.Module):
       mask = mask.astype(h.dtype)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

+ 2 - 2
exo/inference/mlx/sharded_model.py

@@ -38,7 +38,7 @@ class StatefulShardedModel:
         if top_p > 0 and top_p < 1.0:
           token = top_p_sampling(logits, top_p, temp)
         else:
-          token = mx.random.categorical(logits * (1 / temp))
+          token = mx.random.categorical(logits*(1/temp))

       return token

@@ -74,7 +74,7 @@ class StatefulShardedModel:
     return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)

   def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
     if self.max_kv_size is not None:
       cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
     else:
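
In the first hunk, logits*(1/temp) is plain temperature scaling (multiplying by 1/temp is the same as dividing by temp): temperatures below 1 sharpen the categorical distribution and temperatures above 1 flatten it. A small NumPy illustration (not part of the commit):

  import numpy as np

  def softmax(x):
    e = np.exp(x - x.max())  # subtract max for numerical stability
    return e/e.sum()

  logits = np.array([2.0, 1.0, 0.0])
  print(softmax(logits*(1/0.5)))  # temp=0.5, sharper: ~[0.87, 0.12, 0.02]
  print(softmax(logits*(1/2.0)))  # temp=2.0, flatter: ~[0.51, 0.31, 0.19]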

+ 3 - 3
exo/inference/mlx/sharded_utils.py

@@ -60,7 +60,7 @@ def _get_classes(config: dict):

 def load_config(model_path: Path) -> dict:
   try:
-    with open(model_path / "config.json", "r") as f:
+    with open(model_path/"config.json", "r") as f:
       config = json.load(f)
   except FileNotFoundError:
     logging.error(f"Config file not found in {model_path}")
@@ -103,11 +103,11 @@ def load_model_shard(
     "n_layers": shard.n_layers,
   }

-  weight_files = glob.glob(str(model_path / "model*.safetensors"))
+  weight_files = glob.glob(str(model_path/"model*.safetensors"))

   if not weight_files:
     # Try weight for back-compat
-    weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+    weight_files = glob.glob(str(model_path/"weight*.safetensors"))

   if not weight_files:
     logging.error(f"No safetensors found in {model_path}")

+ 1 - 1
exo/inference/mlx/test_sharded_model.py

@@ -38,7 +38,7 @@ model.save_weights("./test_weights.npz")
 n_layers = 5
 shard1 = Shard("test", 0, n_layers // 2, n_layers)
 sharded_model1 = DummyModel(shard1)
-shard2 = Shard("test", n_layers // 2 + 1, n_layers - 1, n_layers)
+shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
 sharded_model2 = DummyModel(shard2)

 model.load_weights("./test_weights.npz")
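
The n_layers//2 + 1 spacing looks like arithmetic_precedence_indication = True at work: yapf drops the spaces around the higher-precedence // while keeping them around the lower-precedence + and -, so the spacing itself hints at evaluation order. The same pattern appears in a*c - b*d in the tinygrad llama hunks below. For example:

  n_layers = 5
  start = n_layers//2 + 1  # // binds tighter than +, and the spacing shows it
  end = n_layers - 1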

+ 5 - 5
exo/inference/tinygrad/inference.py

@@ -33,9 +33,9 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No

   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"), shard)
-    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"), shard)
-    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
+    elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
+    else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
   else:
     weights = load(str(model_path), shard)
   weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
@@ -60,7 +60,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     toks = self.tokenizer.encode(prompt)
     h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()

-    if h.shape == (1, ):
+    if h.shape == (1,):
       start_pos += len(toks)
       start_pos += 1
       n_captured_toks = 0
@@ -76,7 +76,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):

     h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()

-    if h.shape == (1, ):
+    if h.shape == (1,):
       start_pos += n_captured_toks
       start_pos += 1
       n_captured_toks = 0

+ 18 - 18
exo/inference/tinygrad/models/llama.py

@@ -5,8 +5,8 @@ from tinygrad.helpers import getenv

 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta**(Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
-  freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
+  freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+  freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)

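For reference, the reformatted precompute_freqs_cis still computes the standard RoPE frequency table; with head dimension d and base theta (default 10000), in LaTeX:

  \theta_j = \theta^{-2j/d}, \qquad f_{p,j} = p\,\theta_j, \qquad p = 0,\dots,\mathrm{end}-1,\quad j = 0,\dots,d/2-1,

and the returned tensor stacks \cos f_{p,j} and \sin f_{p,j} along the trailing axis.
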
@@ -14,8 +14,8 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtype
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
   a, b = A[..., 0:1], A[..., 1:2]
-  ro = a * c - b * d
-  co = a * d + b * c
+  ro = a*c - b*d
+  co = a*d + b*c
   return ro.cat(co, dim=-1)


@@ -34,7 +34,7 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
   if n_rep == 1: return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
-  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
+  return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)


 class Attention:
@@ -45,10 +45,10 @@ class Attention:
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context

-    self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
-    self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
-    self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
+    self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
+    self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
+    self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
+    self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)

   def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
     if getenv("WQKV"):
@@ -93,7 +93,7 @@ class FeedForward:
     self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit

   def __call__(self, x: Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
+    return self.w2(self.w1(x).silu()*self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]


 class TransformerBlock:
@@ -121,29 +121,29 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   if af or ap:
     if not hasattr(sample, "alpha_counter"):
       setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
-    logits = logits - (sample.alpha_counter * af + (sample.alpha_counter > 0) * ap)
+    logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)

   # replace NaNs with -inf
   logits = (logits != logits).where(-float("inf"), logits)

   # softmax
-  t = (logits / temp).softmax()
+  t = (logits/temp).softmax()

   counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
   # top k
   if k:
     output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
     for i in range(k):
-      t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
-      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1), ))
-      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1), ))
+      t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
+      output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
+      output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
       t = (counter == t_argmax).where(0, t)

     # approximate top p
     # because we are already limited to top k elements we can do top p "without sorting"
     output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
-    output = (output_cumsum >= (1 - p)) * output
-    output_indices = (output_cumsum >= (1 - p)) * output_indices
+    output = (output_cumsum >= (1 - p))*output
+    output_indices = (output_cumsum >= (1 - p))*output_indices

     # sample
     output_idx = output.multinomial()
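
The "top p without sorting" trick works because the loop above extracts the top-k probabilities in descending order, so a reverse cumulative sum gives each candidate's tail mass directly. A hypothetical NumPy re-derivation of the masking step (names are illustrative, not from the repo):

  import numpy as np

  def top_k_top_p_mask(t: np.ndarray, k: int, p: float):
    # top-k probabilities in descending order, like the argmax loop above
    idx = np.argsort(t)[::-1][:k]
    output = t[idx]
    # reverse cumsum = tail mass from each rank onward, plus mass left outside the top k
    output_cumsum = output[::-1].cumsum()[::-1] + (t.sum() - output.sum())
    keep = output_cumsum >= (1 - p)  # zero out candidates once tail mass drops below 1 - p
    return output*keep, idx*keep
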
@@ -183,7 +183,7 @@ class Transformer:
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta).contiguous()
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard


+ 1 - 1
exo/inference/tinygrad/tinygrad_helpers.py

@@ -37,7 +37,7 @@ def load(fn: str, shard: Shard):
         if layer_num < shard.start_layer or layer_num > shard.end_layer:
           continue

-      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
+      parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
       filtered_weight_map[k] = n
     if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
     return {k: parts[n][k] for k, n in filtered_weight_map.items()}

+ 21 - 25
exo/models.py

@@ -2,32 +2,28 @@ from exo.inference.shard import Shard

 model_base_shards = {
   ### llama
-  "llama-3.1-8b":
-    {
-      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-      "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
-    },
-  "llama-3.1-70b":
-    {
-      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-      "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
-    },
-  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126), },
-  "llama-3-8b":
-    {
-      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
-    },
-  "llama-3-70b":
-    {
-      "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-      "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-    },
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+  },
   ### mistral
-  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40), },
-  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88), },
+  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
+  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
   ### deepseek v2
-  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27), },
+  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
   ### llava
-  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32), },
+  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
 }

+ 1 - 1
exo/networking/grpc/grpc_peer_handle.py

@@ -27,7 +27,7 @@ class GRPCPeerHandle(PeerHandle):
     return self._device_capabilities

   async def connect(self):
-    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
+    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
     self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)

   async def is_connected(self) -> bool:

+ 3 - 3
exo/networking/grpc/grpc_server.py

@@ -21,9 +21,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     self.server = grpc.aio.server(
       futures.ThreadPoolExecutor(max_workers=10),
       options=[
-        ("grpc.max_metadata_size", 32 * 1024 * 1024),
-        ("grpc.max_send_message_length", 128 * 1024 * 1024),
-        ("grpc.max_receive_message_length", 128 * 1024 * 1024),
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ("grpc.max_send_message_length", 128*1024*1024),
+        ("grpc.max_receive_message_length", 128*1024*1024),
       ],
     )
     node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)

+ 1 - 1
exo/networking/grpc/node_service_pb2_grpc.py

@@ -150,7 +150,7 @@ def add_NodeServiceServicer_to_server(servicer, server):
       ),
   }
   generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
-  server.add_generic_rpc_handlers((generic_handler, ))
+  server.add_generic_rpc_handlers((generic_handler,))
   server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)



+ 46 - 54
exo/orchestration/standard_node.py

@@ -84,19 +84,17 @@ class StandardNode(Node):
     asyncio.create_task(
       self.broadcast_opaque_status(
         request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "start_process_prompt",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "prompt": prompt,
-            "image_str": image_str,
-            "inference_state": inference_state,
-            "request_id": request_id,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "image_str": image_str,
+          "inference_state": inference_state,
+          "request_id": request_id,
+        }),
       )
     )
     start_time = time.perf_counter_ns()
@@ -106,21 +104,19 @@ class StandardNode(Node):
     asyncio.create_task(
       self.broadcast_opaque_status(
         request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "end_process_prompt",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "prompt": prompt,
-            "image_str": image_str,
-            "inference_state": inference_state,
-            "request_id": request_id,
-            "elapsed_time_ns": elapsed_time_ns,
-            "result_size": resp.size if resp is not None else 0,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_prompt",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "prompt": prompt,
+          "image_str": image_str,
+          "inference_state": inference_state,
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+          "result_size": resp.size if resp is not None else 0,
+        }),
       )
     )
     return resp
@@ -166,19 +162,17 @@ class StandardNode(Node):
     asyncio.create_task(
       self.broadcast_opaque_status(
         request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "start_process_tensor",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "tensor_size": tensor.size,
-            "tensor_shape": tensor.shape,
-            "request_id": request_id,
-            "inference_state": inference_state,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "start_process_tensor",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "tensor_size": tensor.size,
+          "tensor_shape": tensor.shape,
+          "request_id": request_id,
+          "inference_state": inference_state,
+        }),
       )
     )
     start_time = time.perf_counter_ns()
@@ -188,18 +182,16 @@ class StandardNode(Node):
     asyncio.create_task(
       self.broadcast_opaque_status(
         request_id,
-        json.dumps(
-          {
-            "type": "node_status",
-            "node_id": self.id,
-            "status": "end_process_tensor",
-            "base_shard": base_shard.to_dict(),
-            "shard": shard.to_dict(),
-            "request_id": request_id,
-            "elapsed_time_ns": elapsed_time_ns,
-            "result_size": resp.size if resp is not None else 0,
-          }
-        ),
+        json.dumps({
+          "type": "node_status",
+          "node_id": self.id,
+          "status": "end_process_tensor",
+          "base_shard": base_shard.to_dict(),
+          "shard": shard.to_dict(),
+          "request_id": request_id,
+          "elapsed_time_ns": elapsed_time_ns,
+          "result_size": resp.size if resp is not None else 0,
+        }),
       )
     )
     return resp
@@ -257,7 +249,7 @@ class StandardNode(Node):
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
     if current_partition_index is not None:
-      next_partition_index = (current_partition_index + 1) % len(partitions)
+      next_partition_index = (current_partition_index+1) % len(partitions)
       next_partition: Partition = partitions[next_partition_index]
       next_shard = shards[next_partition_index]
       if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")

+ 1 - 1
exo/stats/metrics.py

@@ -24,6 +24,6 @@ def start_metrics_server(node: Node, port: int):
     elif status == "end_process_tensor":
       elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
       PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
-      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9)  # Convert ns to seconds
+      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)  # Convert ns to seconds

   node.on_opaque_status.register("stats").on_next(_on_opaque_status)

+ 63 - 63
exo/topology/device_capabilities.py

@@ -44,78 +44,78 @@ CHIP_FLOPS = {
   # Source: https://www.cpu-monkey.com
   # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
   ### M chips
-  "Apple M1": DeviceFlops(fp32=2.29 * TFLOPS, fp16=4.58 * TFLOPS, int8=9.16 * TFLOPS),
-  "Apple M1 Pro": DeviceFlops(fp32=5.30 * TFLOPS, fp16=10.60 * TFLOPS, int8=21.20 * TFLOPS),
-  "Apple M1 Max": DeviceFlops(fp32=10.60 * TFLOPS, fp16=21.20 * TFLOPS, int8=42.40 * TFLOPS),
-  "Apple M1 Ultra": DeviceFlops(fp32=21.20 * TFLOPS, fp16=42.40 * TFLOPS, int8=84.80 * TFLOPS),
-  "Apple M2": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
-  "Apple M2 Pro": DeviceFlops(fp32=5.68 * TFLOPS, fp16=11.36 * TFLOPS, int8=22.72 * TFLOPS),
-  "Apple M2 Max": DeviceFlops(fp32=13.49 * TFLOPS, fp16=26.98 * TFLOPS, int8=53.96 * TFLOPS),
-  "Apple M2 Ultra": DeviceFlops(fp32=26.98 * TFLOPS, fp16=53.96 * TFLOPS, int8=107.92 * TFLOPS),
-  "Apple M3": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS),
-  "Apple M3 Pro": DeviceFlops(fp32=4.97 * TFLOPS, fp16=9.94 * TFLOPS, int8=19.88 * TFLOPS),
-  "Apple M4": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
+  "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
+  "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
+  "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
+  "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
+  "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
+  "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
+  "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
   ### A chips
-  "Apple A13 Bionic": DeviceFlops(fp32=0.69 * TFLOPS, fp16=1.38 * TFLOPS, int8=2.76 * TFLOPS),
-  "Apple A14 Bionic": DeviceFlops(fp32=0.75 * TFLOPS, fp16=1.50 * TFLOPS, int8=3.00 * TFLOPS),
-  "Apple A15 Bionic": DeviceFlops(fp32=1.37 * TFLOPS, fp16=2.74 * TFLOPS, int8=5.48 * TFLOPS),
-  "Apple A16 Bionic": DeviceFlops(fp32=1.79 * TFLOPS, fp16=3.58 * TFLOPS, int8=7.16 * TFLOPS),
-  "Apple A17 Pro": DeviceFlops(fp32=2.15 * TFLOPS, fp16=4.30 * TFLOPS, int8=8.60 * TFLOPS),
+  "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
+  "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
+  "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
+  "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
+  "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
   ### NVIDIA GPUs
   # RTX 40 series
-  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58 * TFLOPS, fp16=165.16 * TFLOPS, int8=330.32 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74 * TFLOPS, fp16=97.48 * TFLOPS, int8=194.96 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0 * TFLOPS, fp16=104.0 * TFLOPS, int8=208.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43 * TFLOPS, fp16=78.86 * TFLOPS, int8=157.72 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0 * TFLOPS, fp16=60.0 * TFLOPS, int8=120.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0 * TFLOPS, fp16=58.0 * TFLOPS, int8=116.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0 * TFLOPS, fp16=44.0 * TFLOPS, int8=88.0 * TFLOPS),
+  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
+  "NVIDIA GEFORCE RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
+  "NVIDIA GEFORCE RTX 4080 SUPER": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 TI SUPER": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 TI": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
   # RTX 30 series
-  "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8 * TFLOPS, fp16=59.6 * TFLOPS, int8=119.2 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6 * TFLOPS, fp16=61.2 * TFLOPS, int8=122.4 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1 * TFLOPS, fp16=68.2 * TFLOPS, int8=136.4 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6 * TFLOPS, fp16=71.2 * TFLOPS, int8=142.4 * TFLOPS),
-  "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
+  "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
+  "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 3060 TI": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
+  "NVIDIA GEFORCE RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3070 TI": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3080 TI": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
+  "NVIDIA GEFORCE RTX 3090 TI": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
  # QUADRO RTX Ampere series
  # QUADRO RTX Ampere series
-  "NVIDIA QUADRO RTX A2000": DeviceFlops(fp32=7.99 * TFLOPS, fp16=7.99 * TFLOPS, int8=31.91 * TFLOPS),
-  "NVIDIA QUADRO RTX A4000": DeviceFlops(fp32=19.17 * TFLOPS, fp16=19.17 * TFLOPS, int8=76.68 * TFLOPS),
-  "NVIDIA QUADRO RTX A4500": DeviceFlops(fp32=23.65 * TFLOPS, fp16=23.65 * TFLOPS, int8=94.6 * TFLOPS),
-  "NVIDIA QUADRO RTX A5000": DeviceFlops(fp32=27.8 * TFLOPS, fp16=27.8 * TFLOPS, int8=111.2 * TFLOPS),
-  "NVIDIA QUADRO RTX A6000": DeviceFlops(fp32=38.71 * TFLOPS, fp16=38.71 * TFLOPS, int8=154.84 * TFLOPS),
+  "NVIDIA QUADRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
+  "NVIDIA QUADRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
+  "NVIDIA QUADRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
+  "NVIDIA QUADRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
+  "NVIDIA QUADRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
   # Common Server GPUs
   # Common Server GPUs
-  "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4 * TFLOPS, fp16=149.7 * TFLOPS, int8=299.3 * TFLOPS),
-  "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
-  "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5 * TFLOPS, fp16=312.0 * TFLOPS, int8=624.0 * TFLOPS),
+  "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
+  "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A100 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+  "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
   # ... add more devices if needed ...
   # ... add more devices if needed ...
   ### AMD GPUs
   ### AMD GPUs
   # RX 6000 series
   # RX 6000 series
-  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04 * TFLOPS, fp16=46.08 * TFLOPS, int8=92.16 * TFLOPS),
-  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74 * TFLOPS, fp16=41.48 * TFLOPS, int8=82.96 * TFLOPS),
-  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17 * TFLOPS, fp16=32.34 * TFLOPS, int8=64.68 * TFLOPS),
-  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21 * TFLOPS, fp16=26.42 * TFLOPS, int8=52.84 * TFLOPS),
-  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4 * TFLOPS, fp16=22.8 * TFLOPS, int8=45.6 * TFLOPS),
-  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6 * TFLOPS, fp16=21.2 * TFLOPS, int8=42.4 * TFLOPS),
-  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93 * TFLOPS, fp16=17.86 * TFLOPS, int8=35.72 * TFLOPS),
-  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77 * TFLOPS, fp16=11.54 * TFLOPS, int8=23.08 * TFLOPS),
-  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57 * TFLOPS, fp16=7.14 * TFLOPS, int8=14.28 * TFLOPS),
+  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
+  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
+  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
+  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
+  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
+  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
+  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
+  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
+  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
   # RX 7000 series
   # RX 7000 series
-  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4 * TFLOPS, fp16=122.8 * TFLOPS, int8=245.6 * TFLOPS),
-  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4 * TFLOPS, fp16=106.8 * TFLOPS, int8=213.6 * TFLOPS),
-  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6 * TFLOPS, fp16=85.2 * TFLOPS, int8=170.4 * TFLOPS),
-  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2 * TFLOPS, fp16=68.4 * TFLOPS, int8=136.8 * TFLOPS),
-  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5 * TFLOPS, fp16=43.0 * TFLOPS, int8=86.0 * TFLOPS),
-  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
+  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
+  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
+  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
+  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
+  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
+  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
   # ... add more devices if needed ...
   # ... add more devices if needed ...
   ### Qualcomm embedded chips: TODO
   ### Qualcomm embedded chips: TODO
 }
 }
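
For reference, entries in this table pair an upper-cased marketing name with per-precision throughput. A minimal, self-contained sketch of how such a table can be consulted — TFLOPS and DeviceFlops are names from this diff, but the definitions below are assumptions for illustration only:

from collections import namedtuple

# Assumed definitions: the real TFLOPS constant may simply be 1.0
# (values kept in TFLOPS units), and DeviceFlops a small record type.
TFLOPS = 1.0
DeviceFlops = namedtuple("DeviceFlops", ["fp32", "fp16", "int8"])

CHIP_FLOPS = {
  "NVIDIA GEFORCE RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
}

def lookup_flops(chip_name: str) -> DeviceFlops:
  # Keys are upper-cased marketing names, so normalize before matching.
  return CHIP_FLOPS.get(chip_name.upper(), DeviceFlops(0, 0, 0))

print(lookup_flops("NVIDIA GeForce RTX 4090").fp16)  # 165.16 when TFLOPS == 1.0
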
@@ -151,7 +151,7 @@ def mac_device_capabilities() -> DeviceCapabilities:
   memory_units = memory_str.split()
   memory_units = memory_str.split()
   memory_value = int(memory_units[0])
   memory_value = int(memory_units[0])
   if memory_units[1] == "GB":
   if memory_units[1] == "GB":
-    memory = memory_value * 1024
+    memory = memory_value*1024
   else:
   else:
     memory = memory_value
     memory = memory_value
 
 

+ 2 - 2
exo/topology/partitioning_strategy.py

@@ -22,8 +22,8 @@ class PartitioningStrategy(ABC):
 def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
 def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
   shards = []
   shards = []
   for i, partition in enumerate(partitions):
   for i, partition in enumerate(partitions):
-    start_layer = int(partition.start * num_layers)
-    end_layer = int(partition.end * num_layers) - 1
+    start_layer = int(partition.start*num_layers)
+    end_layer = int(partition.end*num_layers) - 1
 
 
     # Ensure the last partition covers up to num_layers - 1
     # Ensure the last partition covers up to num_layers - 1
     if i == len(partitions) - 1:
     if i == len(partitions) - 1:
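
The two rewritten lines convert fractional partition bounds into inclusive integer layer ranges; a quick worked example with assumed inputs:

# Worked example of the mapping above (illustrative values, not from the diff).
num_layers = 32
partitions = [(0.0, 0.5), (0.5, 1.0)]  # (start, end) fractions
for start, end in partitions:
  start_layer = int(start*num_layers)   # 0, then 16
  end_layer = int(end*num_layers) - 1   # 15, then 31
  print(start_layer, end_layer)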

+ 1 - 1
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -12,7 +12,7 @@ class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
     partitions = []
     partitions = []
     start = 0
     start = 0
     for node in nodes:
     for node in nodes:
-      end = round(start + (node[1].memory / total_memory), 5)
+      end = round(start + (node[1].memory/total_memory), 5)
       partitions.append(Partition(node[0], start, end))
       partitions.append(Partition(node[0], start, end))
       start = end
       start = end
     return partitions
     return partitions
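
The rewritten expression accumulates each node's share of total memory, rounded to five decimals. With the 128 GB / 192 GB / 128 GB nodes used in the tests further down, the running bounds come out as below — note the final end lands on 0.99999 rather than 1.0, which is why map_partitions_to_shards clamps the last partition:

# Cumulative memory-weighted bounds, mirroring the expression above.
memories = [128, 192, 128]      # GB; illustrative, matching the tests below
total_memory = sum(memories)    # 448
start = 0
for memory in memories:
  end = round(start + memory/total_memory, 5)
  print(start, end)             # 0 0.28571 / 0.28571 0.71428 / 0.71428 0.99999
  start = end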

+ 1 - 1
exo/topology/test_device_capabilities.py

@@ -80,7 +80,7 @@ Activation Lock Status: Disabled
     self.assertEqual(result.model, "MacBook Pro")
     self.assertEqual(result.model, "MacBook Pro")
     self.assertEqual(result.chip, "Apple M3 Max")
     self.assertEqual(result.chip, "Apple M3 Max")
     self.assertEqual(result.memory, 131072)  # 128 GB in MB
     self.assertEqual(result.memory, 131072)  # 128 GB in MB
-    self.assertEqual(result.flops, DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS))
+    self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
     self.assertEqual(
     self.assertEqual(
       str(result),
       str(result),
       "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
       "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",

+ 2 - 2
exo/topology/test_map_partitions.py

@@ -56,8 +56,8 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
     def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
     def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
       shards = []
       shards = []
       for i, partition in enumerate(partitions):
       for i, partition in enumerate(partitions):
-        start_layer = int(partition.start * num_layers)
-        end_layer = int(partition.end * num_layers) - 1
+        start_layer = int(partition.start*num_layers)
+        end_layer = int(partition.end*num_layers) - 1
         shards.append(Shard(model_id, start_layer, end_layer, num_layers))
         shards.append(Shard(model_id, start_layer, end_layer, num_layers))
       return shards
       return shards
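
_broken_map_partitions_to_shards deliberately omits that last-partition clamp, presumably so the test can confirm the real mapping guards against the rounding gap; concretely:

# Why the clamp matters (illustrative): with a cumulative end rounded to
# 0.99999, the unclamped arithmetic silently loses the final layer.
num_layers = 32
end = 0.99999
end_layer = int(end*num_layers) - 1   # int(31.99968) - 1 == 30, not 31
print(end_layer)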
 
 

+ 3 - 3
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -49,7 +49,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
       DeviceCapabilities(
       DeviceCapabilities(
         model="MacBook Pro",
         model="MacBook Pro",
         chip="test1",
         chip="test1",
-        memory=128 * 1024 * 1024 * 1024,
+        memory=128*1024*1024*1024,
         flops=DeviceFlops(fp32=0, fp16=0, int8=0),
         flops=DeviceFlops(fp32=0, fp16=0, int8=0),
       ),
       ),
     )
     )
@@ -58,7 +58,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
       DeviceCapabilities(
       DeviceCapabilities(
         model="Mac Studio",
         model="Mac Studio",
         chip="test2",
         chip="test2",
-        memory=192 * 1024 * 1024 * 1024,
+        memory=192*1024*1024*1024,
         flops=DeviceFlops(fp32=0, fp16=0, int8=0),
         flops=DeviceFlops(fp32=0, fp16=0, int8=0),
       ),
       ),
     )
     )
@@ -67,7 +67,7 @@ class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
       DeviceCapabilities(
       DeviceCapabilities(
         model="MacBook Pro",
         model="MacBook Pro",
         chip="test3",
         chip="test3",
-        memory=128 * 1024 * 1024 * 1024,
+        memory=128*1024*1024*1024,
         flops=DeviceFlops(fp32=0, fp16=0, int8=0),
         flops=DeviceFlops(fp32=0, fp16=0, int8=0),
       ),
       ),
     )
     )

+ 4 - 4
exo/viz/test_topology_viz.py

@@ -66,19 +66,19 @@ class TestNodeViz(unittest.IsolatedAsyncioTestCase):
     self.topology = Topology()
     self.topology = Topology()
     self.topology.update_node(
     self.topology.update_node(
       "node1",
       "node1",
-      DeviceCapabilities(model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
+      DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
     )
     )
     self.topology.update_node(
     self.topology.update_node(
       "node2",
       "node2",
-      DeviceCapabilities(model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
+      DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
     )
     )
     self.topology.update_node(
     self.topology.update_node(
       "node3",
       "node3",
-      DeviceCapabilities(model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
+      DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
     )
     )
     self.topology.update_node(
     self.topology.update_node(
       "node4",
       "node4",
-      DeviceCapabilities(model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
+      DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
     )
     )
 
 
     self.top_viz = TopologyViz()
     self.top_viz = TopologyViz()

+ 22 - 22
exo/viz/topology_viz.py

@@ -99,7 +99,7 @@ class TopologyViz:
       # Process prompt
       # Process prompt
       prompt_lines = prompt.split('\n')
       prompt_lines = prompt.split('\n')
       if len(prompt_lines) > max_lines // 2:
       if len(prompt_lines) > max_lines // 2:
-        prompt_lines = prompt_lines[:max_lines // 2 - 1] + ['...']
+        prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
       prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
       prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
       prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
       prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
 
 
@@ -139,7 +139,7 @@ class TopologyViz:
     max_line_length = max(len(line) for line in exo_lines)
     max_line_length = max(len(line) for line in exo_lines)
     for i, line in enumerate(exo_lines):
     for i, line in enumerate(exo_lines):
       centered_line = line.center(max_line_length)
       centered_line = line.center(max_line_length)
-      start_x = (100 - max_line_length) // 2 + 15
+      start_x = (100-max_line_length) // 2 + 15
       colored_line = Text(centered_line, style=yellow_style)
       colored_line = Text(centered_line, style=yellow_style)
       for j, char in enumerate(str(colored_line)):
       for j, char in enumerate(str(colored_line)):
         if 0 <= start_x + j < 100 and i < len(visualization):
         if 0 <= start_x + j < 100 and i < len(visualization):
@@ -161,18 +161,18 @@ class TopologyViz:
 
 
     # Calculate total FLOPS and position on the bar
     # Calculate total FLOPS and position on the bar
     total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
     total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
-    bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
+    bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2
 
 
     # Add GPU poor/rich bar
     # Add GPU poor/rich bar
     bar_width = 30
     bar_width = 30
-    bar_start_x = (100 - bar_width) // 2
+    bar_start_x = (100-bar_width) // 2
     bar_y = info_start_y + len(info_lines) + 1
     bar_y = info_start_y + len(info_lines) + 1
 
 
     # Create a gradient bar using emojis
     # Create a gradient bar using emojis
     gradient_bar = Text()
     gradient_bar = Text()
     emojis = ["🟥", "🟧", "🟨", "🟩"]
     emojis = ["🟥", "🟧", "🟨", "🟩"]
     for i in range(bar_width):
     for i in range(bar_width):
-      emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
+      emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1)
       gradient_bar.append(emojis[emoji_index])
       gradient_bar.append(emojis[emoji_index])
 
 
     # Add the gradient bar to the visualization
     # Add the gradient bar to the visualization
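
The GPU poor/rich indicator squashes the cluster's summed fp16 TFLOPS through tanh so the needle saturates at both ends of the bar; a few sample values of the rewritten expression:

import math

# bar_pos = (tanh(total_flops/20 - 2) + 1)/2, as in the hunk above.
for total_flops in (20, 40, 80):  # cluster fp16 TFLOPS (illustrative)
  bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2
  print(total_flops, round(bar_pos, 3))  # 20 -> 0.119, 40 -> 0.5, 80 -> 0.982
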
@@ -183,10 +183,10 @@ class TopologyViz:
 
 
     # Add labels
     # Add labels
     visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
     visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
-    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2:bar_start_x + bar_width * 2 + 11] = "GPU rich"
+    visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich"
 
 
     # Add position indicator and FLOPS value
     # Add position indicator and FLOPS value
-    pos_x = bar_start_x + int(bar_pos * bar_width)
+    pos_x = bar_start_x + int(bar_pos*bar_width)
     flops_str = f"{total_flops:.2f} TFLOPS"
     flops_str = f"{total_flops:.2f} TFLOPS"
     visualization[bar_y - 1][pos_x] = "▼"
     visualization[bar_y - 1][pos_x] = "▼"
     visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
     visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
@@ -198,9 +198,9 @@ class TopologyViz:
     for i, partition in enumerate(self.partitions):
     for i, partition in enumerate(self.partitions):
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
       device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
 
 
-      angle = 2 * math.pi * i / num_partitions
-      x = int(center_x + radius_x * math.cos(angle))
-      y = int(center_y + radius_y * math.sin(angle))
+      angle = 2*math.pi*i/num_partitions
+      x = int(center_x + radius_x*math.cos(angle))
+      y = int(center_y + radius_y*math.sin(angle))
 
 
       # Place node with different color for active node and this node
       # Place node with different color for active node and this node
       if partition.node_id == self.topology.active_node_id:
       if partition.node_id == self.topology.active_node_id:
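
Each partition's node is placed at an evenly spaced angle around an ellipse; a standalone sketch of the placement math (the center and radii here are assumptions, not values from the source):

import math

# Even angular spacing around an ellipse, as in the hunk above.
center_x, center_y, radius_x, radius_y = 50, 24, 30, 12
num_partitions = 4
for i in range(num_partitions):
  angle = 2*math.pi*i/num_partitions
  x = int(center_x + radius_x*math.cos(angle))
  y = int(center_y + radius_y*math.sin(angle))
  print(f"node {i}: ({x}, {y})")  # int() truncates, so points can land one cell off
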
@@ -220,8 +220,8 @@ class TopologyViz:
       # Calculate info position based on angle
       # Calculate info position based on angle
       info_distance_x = radius_x + 6
       info_distance_x = radius_x + 6
       info_distance_y = radius_y + 3
       info_distance_y = radius_y + 3
-      info_x = int(center_x + info_distance_x * math.cos(angle))
-      info_y = int(center_y + info_distance_y * math.sin(angle))
+      info_x = int(center_x + info_distance_x*math.cos(angle))
+      info_y = int(center_y + info_distance_y*math.sin(angle))
 
 
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
       # Adjust text position to avoid overwriting the node icon and prevent cutoff
       if info_x < x:
       if info_x < x:
@@ -230,9 +230,9 @@ class TopologyViz:
         info_x = min(99 - len(max(node_info, key=len)), info_x)
         info_x = min(99 - len(max(node_info, key=len)), info_x)
 
 
       # Adjust for top and bottom nodes
       # Adjust for top and bottom nodes
-      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:
+      if 5*math.pi/4 < angle < 7*math.pi/4:
         info_x += 4
         info_x += 4
-      elif math.pi / 4 < angle < 3 * math.pi / 4:
+      elif math.pi/4 < angle < 3*math.pi/4:
         info_x += 3
         info_x += 3
         info_y -= 2
         info_y -= 2
 
 
@@ -243,16 +243,16 @@ class TopologyViz:
               visualization[info_y + j][info_x + k] = char
               visualization[info_y + j][info_x + k] = char
 
 
       # Draw line to next node
       # Draw line to next node
-      next_i = (i + 1) % num_partitions
-      next_angle = 2 * math.pi * next_i / num_partitions
-      next_x = int(center_x + radius_x * math.cos(next_angle))
-      next_y = int(center_y + radius_y * math.sin(next_angle))
+      next_i = (i+1) % num_partitions
+      next_angle = 2*math.pi*next_i/num_partitions
+      next_x = int(center_x + radius_x*math.cos(next_angle))
+      next_y = int(center_y + radius_y*math.sin(next_angle))
 
 
       # Simple line drawing
       # Simple line drawing
       steps = max(abs(next_x - x), abs(next_y - y))
       steps = max(abs(next_x - x), abs(next_y - y))
       for step in range(1, steps):
       for step in range(1, steps):
-        line_x = int(x + (next_x - x) * step / steps)
-        line_y = int(y + (next_y - y) * step / steps)
+        line_x = int(x + (next_x-x)*step/steps)
+        line_y = int(y + (next_y-y)*step/steps)
         if 0 <= line_y < 48 and 0 <= line_x < 100:
         if 0 <= line_y < 48 and 0 <= line_x < 100:
           visualization[line_y][line_x] = "-"
           visualization[line_y][line_x] = "-"
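
The edge between consecutive nodes is rasterized by stepping linearly along the longer axis; a standalone sketch of the same loop (endpoints illustrative):

# Coarse line drawing between two node centers, mirroring the loop above.
x, y = 20, 24
next_x, next_y = 80, 36
steps = max(abs(next_x - x), abs(next_y - y))  # 60 steps along the longer axis
for step in range(1, steps):
  line_x = int(x + (next_x-x)*step/steps)
  line_y = int(y + (next_y-y)*step/steps)
  print(line_x, line_y)  # each (line_x, line_y) is one "-" cell of the edge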
 
 
@@ -280,7 +280,7 @@ class TopologyViz:
 
 
       for file_path, file_progress in download_progress.file_progress.items():
       for file_path, file_progress in download_progress.file_progress.items():
         if file_progress.status != "complete":
         if file_progress.status != "complete":
-          progress = int(file_progress.downloaded / file_progress.total * 30)
+          progress = int(file_progress.downloaded/file_progress.total*30)
           bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
           bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
           percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
           percentage = f"{file_progress.downloaded / file_progress.total * 100:.0f}%"
           summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
           summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)
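
Each incomplete file gets a 30-character bar scaled by its download fraction; the same arithmetic with illustrative numbers:

# 30-character progress bar, as in the hunk above (illustrative numbers).
downloaded, total = 450_000_000, 1_000_000_000
progress = int(downloaded/total*30)              # 13
bar = f"[{'=' * progress}{' ' * (30 - progress)}]"
percentage = f"{downloaded / total * 100:.0f}%"  # "45%"
print(bar, percentage)
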
@@ -294,7 +294,7 @@ class TopologyViz:
         device = self.topology.nodes.get(node_id)
         device = self.topology.nodes.get(node_id)
         partition = next((p for p in self.partitions if p.node_id == node_id), None)
         partition = next((p for p in self.partitions if p.node_id == node_id), None)
         partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
         partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
-        percentage = progress.downloaded_bytes / progress.total_bytes * 100 if progress.total_bytes > 0 else 0
+        percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0
         speed = pretty_print_bytes_per_second(progress.overall_speed)
         speed = pretty_print_bytes_per_second(progress.overall_speed)
         device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
         device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
         progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
         progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"