
formatting / linting

Alex Cheema, 11 months ago
commit ce761038ac
42 changed files with 3828 additions and 2435 deletions
  1. .pylintrc (+472 -0)
  2. exo/api/chatgpt_api.py (+318 -250)
  3. exo/helpers.py (+111 -101)
  4. exo/inference/debug_inference_engine.py (+51 -21)
  5. exo/inference/inference_engine.py (+11 -6)
  6. exo/inference/mlx/models/sharded_llama.py (+290 -311)
  7. exo/inference/mlx/sharded_inference_engine.py (+23 -16)
  8. exo/inference/mlx/sharded_model.py (+51 -50)
  9. exo/inference/mlx/sharded_utils.py (+200 -187)
  10. exo/inference/mlx/test_sharded_llama.py (+5 -5)
  11. exo/inference/mlx/test_sharded_model.py (+29 -27)
  12. exo/inference/shard.py (+16 -15)
  13. exo/inference/test_inference_engine.py (+42 -12)
  14. exo/inference/tinygrad/inference.py (+190 -91)
  15. exo/inference/tinygrad/models/llama.py (+146 -43)
  16. exo/networking/__init__.py (+1 -1)
  17. exo/networking/discovery.py (+10 -9)
  18. exo/networking/grpc/grpc_discovery.py (+182 -137)
  19. exo/networking/grpc/grpc_peer_handle.py (+102 -81)
  20. exo/networking/grpc/grpc_server.py (+106 -65)
  21. exo/networking/grpc/test_grpc_discovery.py (+14 -13)
  22. exo/networking/peer_handle.py (+44 -39)
  23. exo/networking/server.py (+7 -6)
  24. exo/orchestration/node.py (+43 -38)
  25. exo/orchestration/standard_node.py (+405 -266)
  26. exo/orchestration/test_node.py (+51 -48)
  27. exo/stats/metrics.py (+19 -17)
  28. exo/test_callbacks.py (+33 -30)
  29. exo/topology/device_capabilities.py (+156 -130)
  30. exo/topology/partitioning_strategy.py (+23 -20)
  31. exo/topology/ring_memory_weighted_partitioning_strategy.py (+12 -11)
  32. exo/topology/test_device_capabilities.py (+72 -64)
  33. exo/topology/test_map_partitions.py (+68 -55)
  34. exo/topology/test_ring_memory_weighted_partitioning_strategy.py (+83 -42)
  35. exo/topology/topology.py (+45 -44)
  36. exo/viz/test_topology_viz.py (+57 -29)
  37. exo/viz/topology_viz.py (+160 -154)
  38. format.py (+105 -0)
  39. lint.sh (+5 -0)
  40. pyproject.toml (+17 -0)
  41. ruff.toml (+43 -0)
  42. setup.py (+10 -1)

+ 472 - 0
.pylintrc

@@ -0,0 +1,472 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code
+extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=.*node_service_pb2.*
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint.
+jobs=4
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Specify a configuration file.
+#rcfile=
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once).You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use"--disable=all --enable=classes
+# --disable=W"
+disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
+# E1101 for function binding
+# W0221 for Function class
+# W0105 for comment strings
+# E0401 for missing imports
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member,use-a-generator, no-else-return
+
+
+[REPORTS]
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio).You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=optparse.Values,sys.exit
+
+
+[LOGGING]
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format
+logging-modules=logging
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it working
+# install python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to indicated private dictionary in
+# --spelling-private-dict-file option instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX,
+      TODO
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis. It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging  or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='  '
+
+# Maximum number of characters on a single line.
+max-line-length=150
+
+# Maximum number of lines in a module
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[BASIC]
+
+# Naming style matching correct argument names
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style
+#argument-rgx=
+
+# Naming style matching correct attribute names
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Naming style matching correct class attribute names
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style
+#class-attribute-rgx=
+
+# Naming style matching correct class names
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-style
+#class-rgx=
+
+# Naming style matching correct constant names
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           _
+
+# Include a hint for the correct naming format with invalid-name
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style
+#inlinevar-rgx=
+
+# Naming style matching correct method names
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style
+#method-rgx=
+
+# Naming style matching correct module names
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style
+#variable-rgx=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in a if statement
+max-bool-expr=5
+
+# Maximum number of branch for function / method body
+max-branches=12
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[IMPORTS]
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,
+                   TERMIOS,
+                   Bastion,
+                   rexec
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+[STRING]
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=yes
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=builtins.Exception
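
As an illustrative sketch (not from this diff): the new .pylintrc could be exercised from Python via pylint's programmatic entry point, assuming the lint target is the exo package from the file list above. The exit keyword may be spelled do_exit on older pylint releases.

# Hypothetical usage sketch: lint the exo package with the new .pylintrc.
# Messages and the score report go to stdout.
from pylint.lint import Run

Run(["--rcfile=.pylintrc", "exo"], exit=False)  # exit=False avoids sys.exit()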

+ 318 - 250
exo/api/chatgpt_api.py

@@ -13,279 +13,347 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node
 
 shard_mappings = {
-    ### llama
-    "llama-3.1-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    },
-    "llama-3.1-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    },
-    "llama-3.1-405b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
-    },
-    "llama-3-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
-    },
-    "llama-3-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
-    },
-    ### mistral
-    "mistral-nemo": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-    },
-    "mistral-large": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-    },
+  ### llama
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
+    ),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
+    ),
+  },
+  "llama-3.1-405b": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126
+    ),
+  },
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32
+    ),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80
+    ),
+    "TinygradDynamicShardInferenceEngine": Shard(
+      model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80
+    ),
+  },
+  ### mistral
+  "mistral-nemo": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40
+    ),
+  },
+  "mistral-large": {
+    "MLXDynamicShardInferenceEngine": Shard(
+      model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88
+    ),
+  },
 }
 
+
 class Message:
-    def __init__(self, role: str, content: str):
-        self.role = role
-        self.content = content
+  def __init__(self, role: str, content: str):
+    self.role = role
+    self.content = content
+
 
 class ChatCompletionRequest:
-    def __init__(self, model: str, messages: List[Message], temperature: float):
-        self.model = model
-        self.messages = messages
-        self.temperature = temperature
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
 
 def resolve_tinygrad_tokenizer(model_id: str):
-    if model_id == "llama3-8b-sfr":
-        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-    elif model_id == "llama3-70b-sfr":
-        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-    else:
-        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
+  if model_id == "llama3-8b-sfr":
+    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+  elif model_id == "llama3-70b-sfr":
+    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+  else:
+    raise ValueError(
+      f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}"
+    )
+
 
 async def resolve_tokenizer(model_id: str):
-    try:
-        if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
-        return AutoTokenizer.from_pretrained(model_id)
-    except:
-        import traceback
-        if DEBUG >= 2: print(traceback.format_exc())
-        if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
+  try:
+    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+    return AutoTokenizer.from_pretrained(model_id)
+  except:
+    import traceback
 
-    try:
-        if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
-        return resolve_tinygrad_tokenizer(model_id)
-    except:
-        import traceback
-        if DEBUG >= 2: print(traceback.format_exc())
-        if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
+    if DEBUG >= 2: print(traceback.format_exc())
+    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
+
+  try:
+    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
+    return resolve_tinygrad_tokenizer(model_id)
+  except:
+    import traceback
+
+    if DEBUG >= 2: print(traceback.format_exc())
+    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
+
+  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
+  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+
+  return load_tokenizer(await get_model_path(model_id))
 
-    if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
-    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-    return load_tokenizer(await get_model_path(model_id))
 
 def generate_completion(
-        chat_request: ChatCompletionRequest,
-        tokenizer,
-        prompt: str,
-        request_id: str,
-        tokens: List[int],
-        stream: bool,
-        finish_reason: Union[Literal["length", "stop"], None],
-        object_type: Literal["chat.completion", "text_completion"]
-    ) -> dict:
-    completion = {
-        "id": f"chatcmpl-{request_id}",
-        "object": object_type,
-        "created": int(time.time()),
-        "model": chat_request.model,
-        "system_fingerprint": f"exo_{VERSION}",
-        "choices": [
-            {
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": tokenizer.decode(tokens)
-                },
-                "logprobs": None,
-                "finish_reason": finish_reason,
-            }
-        ]
+  chat_request: ChatCompletionRequest,
+  tokenizer,
+  prompt: str,
+  request_id: str,
+  tokens: List[int],
+  stream: bool,
+  finish_reason: Union[Literal["length", "stop"], None],
+  object_type: Literal["chat.completion", "text_completion"],
+) -> dict:
+  completion = {
+    "id": f"chatcmpl-{request_id}",
+    "object": object_type,
+    "created": int(time.time()),
+    "model": chat_request.model,
+    "system_fingerprint": f"exo_{VERSION}",
+    "choices": [
+      {
+        "index": 0,
+        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+        "logprobs": None,
+        "finish_reason": finish_reason,
+      }
+    ],
+  }
+
+  if not stream:
+    completion["usage"] = {
+      "prompt_tokens": len(tokenizer.encode(prompt)),
+      "completion_tokens": len(tokens),
+      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
     }
 
-    if not stream:
-        completion["usage"] = {
-            "prompt_tokens": len(tokenizer.encode(prompt)),
-            "completion_tokens": len(tokens),
-            "total_tokens": len(tokenizer.encode(prompt)) + len(tokens)
-        }
+  choice = completion["choices"][0]
+  if object_type.startswith("chat.completion"):
+    key_name = "delta" if stream else "message"
+    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
+  elif object_type == "text_completion":
+    choice["text"] = tokenizer.decode(tokens)
+  else:
+    ValueError(f"Unsupported response type: {object_type}")
 
-    choice = completion["choices"][0]
-    if object_type.startswith("chat.completion"):
-        key_name = "delta" if stream else "message"
-        choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
-    elif object_type == "text_completion":
-        choice['text'] = tokenizer.decode(tokens)
-    else:
-        ValueError(f"Unsupported response type: {object_type}")
+  return completion
 
-    return completion
 
 def build_prompt(tokenizer, messages: List[Message]):
-    return tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
+  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
 
 def parse_message(data: dict):
-    if 'role' not in data or 'content' not in data:
-        raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-    return Message(data['role'], data['content'])
+  if "role" not in data or "content" not in data:
+    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+  return Message(data["role"], data["content"])
+
 
 def parse_chat_request(data: dict):
-    return ChatCompletionRequest(
-        data.get('model', 'llama-3.1-8b'),
-        [parse_message(msg) for msg in data['messages']],
-        data.get('temperature', 0.0)
-    )
+  return ChatCompletionRequest(
+    data.get("model", "llama-3.1-8b"),
+    [parse_message(msg) for msg in data["messages"]],
+    data.get("temperature", 0.0),
+  )
+
 
 class ChatGPTAPI:
-    def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
-        self.node = node
-        self.inference_engine_classname = inference_engine_classname
-        self.response_timeout_secs = response_timeout_secs
-        self.app = web.Application()
-        self.prev_token_lens: Dict[str, int] = {}
-        self.stream_tasks: Dict[str, asyncio.Task] = {}
-        cors = aiohttp_cors.setup(self.app)
-        cors_options = aiohttp_cors.ResourceOptions(
-            allow_credentials=True,
-            expose_headers="*",
-            allow_headers="*",
-            allow_methods="*",
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+    self.node = node
+    self.inference_engine_classname = inference_engine_classname
+    self.response_timeout_secs = response_timeout_secs
+    self.app = web.Application()
+    self.prev_token_lens: Dict[str, int] = {}
+    self.stream_tasks: Dict[str, asyncio.Task] = {}
+    cors = aiohttp_cors.setup(self.app)
+    cors_options = aiohttp_cors.ResourceOptions(
+      allow_credentials=True,
+      expose_headers="*",
+      allow_headers="*",
+      allow_methods="*",
+    )
+    cors.add(
+      self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options}
+    )
+    cors.add(
+      self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options}
+    )
+    self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
+    self.app.router.add_get("/", self.handle_root)
+    self.app.router.add_static("/", self.static_dir, name="static")
+
+    # Add middleware to log every request
+    self.app.middlewares.append(self.log_request)
+
+  async def log_request(self, app, handler):
+    async def middleware(request):
+      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
+      return await handler(request)
+
+    return middleware
+
+  async def handle_root(self, request):
+    print(f"Handling root request from {request.remote}")
+    return web.FileResponse(self.static_dir / "index.html")
+
+  async def handle_post_chat_token_encode(self, request):
+    data = await request.json()
+    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    messages = [parse_message(msg) for msg in data.get("messages", [])]
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
+
+  async def handle_post_chat_completions(self, request):
+    data = await request.json()
+    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    chat_request = parse_chat_request(data)
+    if chat_request.model and chat_request.model.startswith(
+      "gpt-"
+    ):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
+      chat_request.model = "llama-3.1-8b"
+    if not chat_request.model or chat_request.model not in shard_mappings:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+      chat_request.model = "llama-3.1-8b"
+    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    if not shard:
+      supported_models = [
+        model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines
+      ]
+      return web.json_response(
+        {
+          "detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"
+        },
+        status=400,
+      )
+    request_id = str(uuid.uuid4())
+
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+
+    prompt = build_prompt(tokenizer, chat_request.messages)
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+    try:
+      await self.node.process_prompt(shard, prompt, request_id=request_id)
+    except Exception as e:
+      if DEBUG >= 2:
+        import traceback
+
+        traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500
+      )
+
+    try:
+      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
+
+      if stream:
+        response = web.StreamResponse(
+          status=200,
+          reason="OK",
+          headers={
+            "Content-Type": "application/json",
+            "Cache-Control": "no-cache",
+          },
+        )
+        await response.prepare(request)
+
+        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
+          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+          new_tokens = tokens[prev_last_tokens_len:]
+          finish_reason = None
+          eos_token_id = (
+            tokenizer.special_tokens_map.get("eos_token_id")
+            if isinstance(tokenizer._tokenizer, AutoTokenizer)
+            else tokenizer.eos_token_id
+          )
+          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
+            new_tokens = new_tokens[:-1]
+            if is_finished:
+              finish_reason = "stop"
+          if is_finished and not finish_reason:
+            finish_reason = "length"
+
+          completion = generate_completion(
+            chat_request,
+            tokenizer,
+            prompt,
+            request_id,
+            new_tokens,
+            stream,
+            finish_reason,
+            "chat.completion",
+          )
+          if DEBUG >= 2: print(f"Streaming completion: {completion}")
+          await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+
+        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
+          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+
+          return _request_id == request_id and is_finished
+
+        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
+        if (
+          request_id in self.stream_tasks
+        ):  # in case there is still a stream task running, wait for it to complete
+          if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
+          try:
+            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
+          except asyncio.TimeoutError:
+            print("WARNING: Stream task timed out. This should not happen.")
+        await response.write_eof()
+        return response
+      else:
+        _, tokens, _ = await callback.wait(
+          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+          timeout=self.response_timeout_secs,
+        )
+
+        finish_reason = "length"
+        eos_token_id = (
+          tokenizer.special_tokens_map.get("eos_token_id")
+          if isinstance(tokenizer._tokenizer, AutoTokenizer)
+          else tokenizer.eos_token_id
         )
-        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post_chat_completions), {
-            "*": cors_options
-        })
-        cors.add(self.app.router.add_post('/v1/chat/token/encode', self.handle_post_chat_token_encode), {
-            "*": cors_options
-        })
-        self.static_dir = Path(__file__).parent.parent.parent / 'tinychat/examples/tinychat'
-        self.app.router.add_get('/', self.handle_root)
-        self.app.router.add_static('/', self.static_dir, name='static')
-
-        # Add middleware to log every request
-        self.app.middlewares.append(self.log_request)
-
-    async def log_request(self, app, handler):
-        async def middleware(request):
-            if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
-            return await handler(request)
-        return middleware
-
-    async def handle_root(self, request):
-        print(f"Handling root request from {request.remote}")
-        return web.FileResponse(self.static_dir / 'index.html')
-
-    async def handle_post_chat_token_encode(self, request):
-        data = await request.json()
-        shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
-        messages = [parse_message(msg) for msg in data.get('messages', [])]
-        tokenizer = await resolve_tokenizer(shard.model_id)
-        return web.json_response({'length': len(build_prompt(tokenizer, messages))})
-
-    async def handle_post_chat_completions(self, request):
-        data = await request.json()
-        if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
-        stream = data.get('stream', False)
-        chat_request = parse_chat_request(data)
-        if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-            chat_request.model = "llama-3.1-8b"
-        if not chat_request.model or chat_request.model not in shard_mappings:
-            if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
-            chat_request.model = "llama-3.1-8b"
-        shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
-        if not shard:
-            supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
-            return web.json_response({'detail': f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, status=400)
-        request_id = str(uuid.uuid4())
-
-        tokenizer = await resolve_tokenizer(shard.model_id)
-        if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
-
-        prompt = build_prompt(tokenizer, chat_request.messages)
-        callback_id = f"chatgpt-api-wait-response-{request_id}"
-        callback = self.node.on_token.register(callback_id)
-
-        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
-        try:
-            await self.node.process_prompt(shard, prompt, request_id=request_id)
-        except Exception as e:
-            if DEBUG >= 2:
-                import traceback
-                traceback.print_exc()
-            return web.json_response({'detail': f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-
-        try:
-            if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
-
-            if stream:
-                response = web.StreamResponse(
-                    status=200,
-                    reason="OK",
-                    headers={
-                        "Content-Type": "application/json",
-                        "Cache-Control": "no-cache",
-                    }
-                )
-                await response.prepare(request)
-
-                async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-                    prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-                    self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
-                    new_tokens = tokens[prev_last_tokens_len:]
-                    finish_reason = None
-                    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-                    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
-                        new_tokens = new_tokens[:-1]
-                        if is_finished:
-                            finish_reason = "stop"
-                    if is_finished and not finish_reason:
-                        finish_reason = "length"
-
-                    completion = generate_completion(chat_request, tokenizer, prompt, request_id, new_tokens, stream, finish_reason, "chat.completion")
-                    if DEBUG >= 2: print(f"Streaming completion: {completion}")
-                    await response.write(f"data: {json.dumps(completion)}\n\n".encode())
-                def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-                    self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
-
-                    return _request_id == request_id and is_finished
-                _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
-                if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
-                    if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
-                    try:
-                        await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
-                    except asyncio.TimeoutError:
-                        print("WARNING: Stream task timed out. This should not happen.")
-                await response.write_eof()
-                return response
-            else:
-                _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
-
-                finish_reason = "length"
-                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-                if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
-                if tokens[-1] == eos_token_id:
-                    tokens = tokens[:-1]
-                    finish_reason = "stop"
-
-                return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
-        except asyncio.TimeoutError:
-            return web.json_response({'detail': "Response generation timed out"}, status=408)
-        finally:
-            deregistered_callback = self.node.on_token.deregister(callback_id)
-            if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
-
-    async def run(self, host: str = "0.0.0.0", port: int = 8000):
-        runner = web.AppRunner(self.app)
-        await runner.setup()
-        site = web.TCPSite(runner, host, port)
-        await site.start()
-        if DEBUG >= 0:
-            print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
-            print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
+        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
+        if tokens[-1] == eos_token_id:
+          tokens = tokens[:-1]
+          finish_reason = "stop"
+
+        return web.json_response(
+          generate_completion(
+            chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"
+          )
+        )
+    except asyncio.TimeoutError:
+      return web.json_response({"detail": "Response generation timed out"}, status=408)
+    finally:
+      deregistered_callback = self.node.on_token.deregister(callback_id)
+      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
+
+  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+    runner = web.AppRunner(self.app)
+    await runner.setup()
+    site = web.TCPSite(runner, host, port)
+    await site.start()
+    if DEBUG >= 0:
+      print(
+        f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}"
+      )
+      print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
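
As an illustrative sketch (not from this diff): a minimal client for the /v1/chat/completions route registered above, assuming a node is serving on the default port 8000 from run() and using only fields the handler reads (model, messages, temperature, stream).

# Hypothetical client sketch: POST a non-streaming chat completion request.
import json
import urllib.request

payload = {
  "model": "llama-3.1-8b",  # unknown or gpt-* model names fall back to this default
  "messages": [{"role": "user", "content": "Hello!"}],
  "temperature": 0.0,
  "stream": False,
}
req = urllib.request.Request(
  "http://localhost:8000/v1/chat/completions",
  data=json.dumps(payload).encode(),
  headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
  completion = json.loads(resp.read())
print(completion["choices"][0]["message"]["content"])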

+ 111 - 101
exo/helpers.py

@@ -18,124 +18,134 @@ exo_text = r"""
     """
 
 def get_system_info():
-    if psutil.MACOS:
-        if platform.machine() == 'arm64':
-            return "Apple Silicon Mac"
-        elif platform.machine() in ['x86_64', 'i386']:
-            return "Intel Mac"
-        else:
-            return "Unknown Mac architecture"
-    elif psutil.LINUX:
-        return "Linux"
+  if psutil.MACOS:
+    if platform.machine() == "arm64":
+      return "Apple Silicon Mac"
+    elif platform.machine() in ["x86_64", "i386"]:
+      return "Intel Mac"
     else:
-        return "Non-Mac, non-Linux system"
+      return "Unknown Mac architecture"
+  elif psutil.LINUX:
+    return "Linux"
+  else:
+    return "Non-Mac, non-Linux system"
+
 
 def get_inference_engine(inference_engine_name):
-    if inference_engine_name == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        return MLXDynamicShardInferenceEngine()
-    elif inference_engine_name == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        return TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {inference_engine_name} not supported")
-
-def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
-    used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
-
-    def read_used_ports():
-        if os.path.exists(used_ports_file):
-            with open(used_ports_file, 'r') as f:
-                return [int(line.strip()) for line in f if line.strip().isdigit()]
-        return []
-
-    def write_used_port(port, used_ports):
-        with open(used_ports_file, 'w') as f:
-            print(used_ports[-19:])
-            for p in used_ports[-19:] + [port]:
-                f.write(f"{p}\n")
-
-    used_ports = read_used_ports()
-    available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
-
-    while available_ports:
-        port = random.choice(list(available_ports))
-        if DEBUG >= 2: print(f"Trying to find available port {port=}")
-        try:
-            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                s.bind((host, port))
-            write_used_port(port, used_ports)
-            return port
-        except socket.error:
-            available_ports.remove(port)
-
-    raise RuntimeError("No available ports in the specified range")
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine()
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+
+    return TinygradDynamicShardInferenceEngine()
+  else:
+    raise ValueError(f"Inference engine {inference_engine_name} not supported")
+
+
+def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
+  used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
+
+  def read_used_ports():
+    if os.path.exists(used_ports_file):
+      with open(used_ports_file, "r") as f:
+        return [int(line.strip()) for line in f if line.strip().isdigit()]
+    return []
+
+  def write_used_port(port, used_ports):
+    with open(used_ports_file, "w") as f:
+      print(used_ports[-19:])
+      for p in used_ports[-19:] + [port]:
+        f.write(f"{p}\n")
+
+  used_ports = read_used_ports()
+  available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
+
+  while available_ports:
+    port = random.choice(list(available_ports))
+    if DEBUG >= 2: print(f"Trying to find available port {port=}")
+    try:
+      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((host, port))
+      write_used_port(port, used_ports)
+      return port
+    except socket.error:
+      available_ports.remove(port)
+
+  raise RuntimeError("No available ports in the specified range")
+
 
 def print_exo():
-    print(exo_text)
+  print(exo_text)
+
 
 def print_yellow_exo():
-    yellow = "\033[93m"  # ANSI escape code for yellow
-    reset = "\033[0m"    # ANSI escape code to reset color
-    print(f"{yellow}{exo_text}{reset}")
+  yellow = "\033[93m"  # ANSI escape code for yellow
+  reset = "\033[0m"  # ANSI escape code to reset color
+  print(f"{yellow}{exo_text}{reset}")
+
 
 def terminal_link(uri, label=None):
-    if label is None: 
-        label = uri
-    parameters = ''
+  if label is None:
+    label = uri
+  parameters = ""
+
+  # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST
+  escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
+
+  return escape_mask.format(parameters, uri, label)
 
-    # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST 
-    escape_mask = '\033]8;{};{}\033\\{}\033]8;;\033\\'
 
-    return escape_mask.format(parameters, uri, label)
+T = TypeVar("T")
+K = TypeVar("K")
 
-T = TypeVar('T')
-K = TypeVar('K')
 
 class AsyncCallback(Generic[T]):
-    def __init__(self) -> None:
-        self.condition: asyncio.Condition = asyncio.Condition()
-        self.result: Optional[Tuple[T, ...]] = None
-        self.observers: list[Callable[..., None]] = []
-
-    async def wait(self,
-                   check_condition: Callable[..., bool],
-                   timeout: Optional[float] = None) -> Tuple[T, ...]:
-        async with self.condition:
-            await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
-            assert self.result is not None  # for type checking
-            return self.result
-
-    def on_next(self, callback: Callable[..., None]) -> None:
-        self.observers.append(callback)
-
-    def set(self, *args: T) -> None:
-        self.result = args
-        for observer in self.observers:
-            observer(*args)
-        asyncio.create_task(self.notify())
-
-    async def notify(self) -> None:
-        async with self.condition:
-            self.condition.notify_all()
+  def __init__(self) -> None:
+    self.condition: asyncio.Condition = asyncio.Condition()
+    self.result: Optional[Tuple[T, ...]] = None
+    self.observers: list[Callable[..., None]] = []
+
+  async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
+    async with self.condition:
+      await asyncio.wait_for(
+        self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout
+      )
+      assert self.result is not None  # for type checking
+      return self.result
+
+  def on_next(self, callback: Callable[..., None]) -> None:
+    self.observers.append(callback)
+
+  def set(self, *args: T) -> None:
+    self.result = args
+    for observer in self.observers:
+      observer(*args)
+    asyncio.create_task(self.notify())
+
+  async def notify(self) -> None:
+    async with self.condition:
+      self.condition.notify_all()
+
 
 class AsyncCallbackSystem(Generic[K, T]):
-    def __init__(self) -> None:
-        self.callbacks: Dict[K, AsyncCallback[T]] = {}
+  def __init__(self) -> None:
+    self.callbacks: Dict[K, AsyncCallback[T]] = {}
 
-    def register(self, name: K) -> AsyncCallback[T]:
-        if name not in self.callbacks:
-            self.callbacks[name] = AsyncCallback[T]()
-        return self.callbacks[name]
+  def register(self, name: K) -> AsyncCallback[T]:
+    if name not in self.callbacks:
+      self.callbacks[name] = AsyncCallback[T]()
+    return self.callbacks[name]
 
-    def deregister(self, name: K) -> None:
-        if name in self.callbacks:
-            del self.callbacks[name]
+  def deregister(self, name: K) -> None:
+    if name in self.callbacks:
+      del self.callbacks[name]
 
-    def trigger(self, name: K, *args: T) -> None:
-        if name in self.callbacks:
-            self.callbacks[name].set(*args)
+  def trigger(self, name: K, *args: T) -> None:
+    if name in self.callbacks:
+      self.callbacks[name].set(*args)
 
-    def trigger_all(self, *args: T) -> None:
-        for callback in self.callbacks.values():
-            callback.set(*args)
+  def trigger_all(self, *args: T) -> None:
+    for callback in self.callbacks.values():
+      callback.set(*args)

+ 51 - 21
exo/inference/debug_inference_engine.py

@@ -5,34 +5,64 @@ from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-    from exo.inference.tinygrad.inference import Tokenizer
-    from pathlib import Path
-    _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
+async def test_inference_engine(
+  inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str
+):
+  from exo.inference.tinygrad.inference import Tokenizer
+  from pathlib import Path
+
+  _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
 
-    prompt = "In a single word only, what is the last name of the president of the United States? "
-    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
+  prompt = "In a single word only, what is the last name of the president of the United States? "
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
+    "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
+  )
+  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+    "A",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    input_data=resp_full,
+    inference_state=inference_state_full,
+  )
 
-    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
-    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
-    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
+    "B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt
+  )
+  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp1,
+    inference_state=inference_state_1,
+  )
+  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
+    input_data=resp2,
+    inference_state=inference_state_2,
+  )
+  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp3,
+    inference_state=inference_state_3,
+  )
 
-    print(f"{resp2=}")
-    print(f"full: {_tokenizer.decode(resp_full)}")
-    print(f"next full: {_tokenizer.decode(next_resp_full)}")
-    print(f"resp2: {_tokenizer.decode(resp2)}")
-    print(f"{resp4=}")
-    print(f"resp4: {_tokenizer.decode(resp4)}")
+  print(f"{resp2=}")
+  print(f"full: {_tokenizer.decode(resp_full)}")
+  print(f"next full: {_tokenizer.decode(next_resp_full)}")
+  print(f"resp2: {_tokenizer.decode(resp2)}")
+  print(f"{resp4=}")
+  print(f"resp4: {_tokenizer.decode(resp4)}")
 
-    assert np.array_equal(resp_full, resp2)
-    assert np.array_equal(next_resp_full, resp4)
+  assert np.array_equal(resp_full, resp2)
+  assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(
+asyncio.run(
+  test_inference_engine(
     TinygradDynamicShardInferenceEngine(),
     TinygradDynamicShardInferenceEngine(),
     "llama3-8b-sfr",
-))
+  )
+)
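
The debug test above builds its Tokenizer directly from a tokenizer.model file inside the model directory, so the model_id passed in doubles as a local path. A minimal sketch of that assumption, with a placeholder directory name:

from pathlib import Path
from exo.inference.tinygrad.inference import Tokenizer

# "llama3-8b-sfr" is assumed here to be a local directory containing tokenizer.model,
# mirroring how the test derives the tokenizer path from model_id.
model_dir = Path("llama3-8b-sfr")
tokenizer = Tokenizer(str(model_dir / "tokenizer.model"))
print(tokenizer.decode(tokenizer.encode("hello")))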

+ 11 - 6
exo/inference/inference_engine.py

@@ -4,11 +4,16 @@ from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
 
+
 class InferenceEngine(ABC):
-    @abstractmethod
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        pass
+  @abstractmethod
+  async def infer_tensor(
+    self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
+  ) -> Tuple[np.ndarray, str, bool]:
+    pass
 
-    @abstractmethod
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        pass
+  @abstractmethod
+  async def infer_prompt(
+    self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
+  ) -> Tuple[np.ndarray, str, bool]:
+    pass
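
Every engine touched by this diff implements the two coroutines above and returns an (np.ndarray, str, bool) triple: output data, serialized inference state, and a finished flag. The subclass below is a hypothetical stand-in, not part of the repo, shown only to illustrate the shape of a conforming implementation:

import numpy as np
from typing import Optional, Tuple

from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard


class EchoEngine(InferenceEngine):
  """Hypothetical engine used only to illustrate the interface; it does no real inference."""

  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    # A real engine would tokenize the prompt and run its layer range; here we just emit byte values.
    return np.array([ord(c) for c in prompt]), "", False

  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
    # Pass the tensor through unchanged and report the request as finished.
    return input_data, inference_state or "", True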

+ 290 - 311
exo/inference/mlx/models/sharded_llama.py

@@ -7,345 +7,324 @@ import mlx.nn as nn
 from exo.inference.shard import Shard
 from mlx_lm.models.base import BaseModelArgs, KVCache, create_additive_causal_mask
 
+
 @dataclass
 class NormalModelArgs(BaseModelArgs):
-    model_type: str
-    hidden_size: int
-    num_hidden_layers: int
-    intermediate_size: int
-    num_attention_heads: int
-    rms_norm_eps: float
-    vocab_size: int
-    head_dim: Optional[int] = None
-    max_position_embeddings: Optional[int] = None
-    num_key_value_heads: Optional[int] = None
-    attention_bias: bool = False
-    mlp_bias: bool = False
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = True
-
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.rope_scaling:
-            if not "factor" in self.rope_scaling:
-                raise ValueError(f"rope_scaling must contain 'factor'")
-            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
-                "rope_type"
-            )
-            if rope_type is None:
-                raise ValueError(
-                    f"rope_scaling must contain either 'type' or 'rope_type'"
-                )
-            if rope_type not in ["linear", "dynamic", "llama3"]:
-                raise ValueError(
-                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
-                )
+  model_type: str
+  hidden_size: int
+  num_hidden_layers: int
+  intermediate_size: int
+  num_attention_heads: int
+  rms_norm_eps: float
+  vocab_size: int
+  head_dim: Optional[int] = None
+  max_position_embeddings: Optional[int] = None
+  num_key_value_heads: Optional[int] = None
+  attention_bias: bool = False
+  mlp_bias: bool = False
+  rope_theta: float = 10000
+  rope_traditional: bool = False
+  rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+  tie_word_embeddings: bool = True
+
+  def __post_init__(self):
+    if self.num_key_value_heads is None:
+      self.num_key_value_heads = self.num_attention_heads
+
+    if self.rope_scaling:
+      if not "factor" in self.rope_scaling:
+        raise ValueError(f"rope_scaling must contain 'factor'")
+      rope_type = self.rope_scaling.get("type") or self.rope_scaling.get("rope_type")
+      if rope_type is None:
+        raise ValueError(f"rope_scaling must contain either 'type' or 'rope_type'")
+      if rope_type not in ["linear", "dynamic", "llama3"]:
+        raise ValueError("rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'")
+
 
 @dataclass
 class ModelArgs(NormalModelArgs):
-    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()  # Ensure parent initializations are respected
 
-    def __post_init__(self):
-        super().__post_init__()  # Ensure parent initializations are respected
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
 
-        if isinstance(self.shard, Shard):
-            return
-        if not isinstance(self.shard, dict):
-            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+    self.shard = Shard(**self.shard)
 
-        self.shard = Shard(**self.shard)
 
 class DynamicNTKScalingRoPE(nn.Module):
-    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
-
-    def __init__(
-        self,
-        dims: int,
-        max_position_embeddings: int = 2048,
-        traditional: bool = False,
-        base: float = 10000,
-        scale: float = 1.0,
-        rope_type: str = "default",
-        rope_scaling: dict = None,
-    ):
-        super().__init__()
-        self.dims = dims
-        self.max_position_embeddings = max_position_embeddings
-        self.traditional = traditional
-        self.original_base = base
-        self.scale = scale
-        self.rope_type = rope_type
-        self.rope_scaling = rope_scaling
-        self.base = self.compute_base_freq()
-
-    def compute_base_freq(self):
-        if self.rope_type == "llama3":
-            return self.compute_llama3_base_freq()
-        return self.original_base
-
-    # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
-    def compute_llama3_base_freq(self):
-        factor = self.rope_scaling["factor"]
-        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
-        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
-        old_context_len = self.rope_scaling.get(
-            "original_max_position_embeddings",
-            8192,
-        )
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-
-        freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
-        wavelens = 2 * mx.pi * freqs
-        new_base_freqs = []
-
-        smooths = (wavelens - high_freq_wavelen) / (
-            low_freq_wavelen - high_freq_wavelen
-        )
-        new_base_freqs = freqs * (1 - smooths) * factor + smooths
-        new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
-        new_base_freqs = mx.where(
-            wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
-        )
-        return new_base_freqs.mean().item()
-
-    def extra_repr(self):
-        return (
-            f"{self.dims}, traditional={self.traditional}, "
-            f"max_position_embeddings={self.max_position_embeddings}, "
-            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
-        )
-
-    def __call__(self, x, offset: int = 0):
-        seq_len = x.shape[1] + offset
-        base = self.base
-        if self.max_position_embeddings and seq_len > self.max_position_embeddings:
-            base *= (
-                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
-            ) ** (self.dims / (self.dims - 2))
-
-        return mx.fast.rope(
-            x,
-            self.dims,
-            traditional=self.traditional,
-            base=base,
-            scale=self.scale,
-            offset=offset,
-        )
+  """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
+
+  def __init__(
+    self,
+    dims: int,
+    max_position_embeddings: int = 2048,
+    traditional: bool = False,
+    base: float = 10000,
+    scale: float = 1.0,
+    rope_type: str = "default",
+    rope_scaling: dict = None,
+  ):
+    super().__init__()
+    self.dims = dims
+    self.max_position_embeddings = max_position_embeddings
+    self.traditional = traditional
+    self.original_base = base
+    self.scale = scale
+    self.rope_type = rope_type
+    self.rope_scaling = rope_scaling
+    self.base = self.compute_base_freq()
+
+  def compute_base_freq(self):
+    if self.rope_type == "llama3":
+      return self.compute_llama3_base_freq()
+    return self.original_base
+
+  # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
+  def compute_llama3_base_freq(self):
+    factor = self.rope_scaling["factor"]
+    low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
+    high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
+    old_context_len = self.rope_scaling.get(
+      "original_max_position_embeddings",
+      8192,
+    )
 
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
 
-def initialize_rope(args: ModelArgs):
-    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
-
-    rope_scaling = args.rope_scaling
-    rope_type = "default"
-    rope_scale = 1.0
-
-    if rope_scaling is not None:
-        rope_type = (
-            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
-        )
-        if rope_type == "linear":
-            rope_scale = 1 / rope_scaling["factor"]
-        elif rope_type == "llama3":
-            rope_scale = 1.0  # The scaling is handled internally for llama3
-
-    return DynamicNTKScalingRoPE(
-        dims=head_dim,
-        max_position_embeddings=args.max_position_embeddings,
-        traditional=args.rope_traditional,
-        base=args.rope_theta,
-        scale=rope_scale,
-        rope_type=rope_type,
-        rope_scaling=rope_scaling,
+    freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
+    wavelens = 2 * mx.pi * freqs
+    new_base_freqs = []
+
+    smooths = (wavelens - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen)
+    new_base_freqs = freqs * (1 - smooths) * factor + smooths
+    new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
+    new_base_freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, new_base_freqs)
+    return new_base_freqs.mean().item()
+
+  def extra_repr(self):
+    return (
+      f"{self.dims}, traditional={self.traditional}, "
+      f"max_position_embeddings={self.max_position_embeddings}, "
+      f"scaling_factor={self.scale}, rope_type={self.rope_type}"
+    )
+
+  def __call__(self, x, offset: int = 0):
+    seq_len = x.shape[1] + offset
+    base = self.base
+    if self.max_position_embeddings and seq_len > self.max_position_embeddings:
+      base *= ((self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)) ** (
+        self.dims / (self.dims - 2)
+      )
+
+    return mx.fast.rope(
+      x,
+      self.dims,
+      traditional=self.traditional,
+      base=base,
+      scale=self.scale,
+      offset=offset,
     )
 
 
+def initialize_rope(args: ModelArgs):
+  head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
+
+  rope_scaling = args.rope_scaling
+  rope_type = "default"
+  rope_scale = 1.0
+
+  if rope_scaling is not None:
+    rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
+    if rope_type == "linear":
+      rope_scale = 1 / rope_scaling["factor"]
+    elif rope_type == "llama3":
+      rope_scale = 1.0  # The scaling is handled internally for llama3
+
+  return DynamicNTKScalingRoPE(
+    dims=head_dim,
+    max_position_embeddings=args.max_position_embeddings,
+    traditional=args.rope_traditional,
+    base=args.rope_theta,
+    scale=rope_scale,
+    rope_type=rope_type,
+    rope_scaling=rope_scaling,
+  )
+
+
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        dim = args.hidden_size
-        self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-
-        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
-
-        self.scale = head_dim**-0.5
-        if hasattr(args, "attention_bias"):
-            attention_bias = args.attention_bias
-        else:
-            attention_bias = False
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
-
-        self.rope = initialize_rope(args)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+
+    dim = args.hidden_size
+    self.n_heads = n_heads = args.num_attention_heads
+    self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+    self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
+    self.scale = head_dim**-0.5
+    if hasattr(args, "attention_bias"):
+      attention_bias = args.attention_bias
+    else:
+      attention_bias = False
+
+    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+    self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+    self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+    self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+    self.rope = initialize_rope(args)
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    B, L, D = x.shape
+
+    queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+    # Prepare the queries, keys and values for the attention computation
+    queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+    keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+    values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+    if cache is not None:
+      queries = self.rope(queries, offset=cache.offset)
+      keys = self.rope(keys, offset=cache.offset)
+      keys, values = cache.update_and_fetch(keys, values)
+    else:
+      queries = self.rope(queries)
+      keys = self.rope(keys)
+
+    output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
+    output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+    return self.o_proj(output)
 
 
 class MLP(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
+  def __init__(self, args: ModelArgs):
+    super().__init__()
 
-        dim = args.hidden_size
-        hidden_dim = args.intermediate_size
-        if hasattr(args, "mlp_bias"):
-            mlp_bias = args.mlp_bias
-        else:
-            mlp_bias = False
+    dim = args.hidden_size
+    hidden_dim = args.intermediate_size
+    if hasattr(args, "mlp_bias"):
+      mlp_bias = args.mlp_bias
+    else:
+      mlp_bias = False
 
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+    self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
+    self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
+    self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
 
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+  def __call__(self, x) -> mx.array:
+    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.num_attention_heads = args.num_attention_heads
-        self.hidden_size = args.hidden_size
-        self.self_attn = Attention(args)
-        self.mlp = MLP(args)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
-        self.args = args
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.num_attention_heads = args.num_attention_heads
+    self.hidden_size = args.hidden_size
+    self.self_attn = Attention(args)
+    self.mlp = MLP(args)
+    self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+    self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+    self.args = args
+
+  def __call__(
+    self,
+    x: mx.array,
+    mask: Optional[mx.array] = None,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    r = self.self_attn(self.input_layernorm(x), mask, cache)
+    h = x + r
+    r = self.mlp(self.post_attention_layernorm(h))
+    out = h + r
+    return out
 
 
 class LlamaModel(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.vocab_size = args.vocab_size
-        self.num_hidden_layers = args.num_hidden_layers
-        assert self.vocab_size > 0
-        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [
-            TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)
-        ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        if self.args.shard.is_first_layer():
-            h = self.embed_tokens(inputs)
-        else:
-            h = inputs
-
-        mask = None
-        if h.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                h.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(h.dtype)
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, cache=c)
-
-        if self.args.shard.is_last_layer():
-            return self.norm(h)
-        else:
-            return h
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = [TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)]
+    self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      return self.norm(h)
+    else:
+      return h
 
 
 class Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.model_type = args.model_type
-        self.model = LlamaModel(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        out = self.model(inputs, cache)
-
-        if self.args.shard.is_last_layer():
-            if self.args.tie_word_embeddings:
-                out = self.model.embed_tokens.as_linear(out)
-            else:
-                out = self.lm_head(out)
-
-        return out
-
-    def sanitize(self, weights):
-        # Remove unused precomputed rotary freqs
-        return {
-            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-        }
-
-    @property
-    def layers(self):
-        return self.model.layers
-
-    @property
-    def head_dim(self):
-        return (
-            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
-        )
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = LlamaModel(args)
+    if not args.tie_word_embeddings:
+      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+
+    if self.args.shard.is_last_layer():
+      if self.args.tie_word_embeddings:
+        out = self.model.embed_tokens.as_linear(out)
+      else:
+        out = self.lm_head(out)
+
+    return out
+
+  def sanitize(self, weights):
+    # Remove unused precomputed rotary freqs
+    return {k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k}
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
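
Note how the sharding works in the model above: LlamaModel instantiates only end_layer - start_layer + 1 transformer blocks, embeds tokens only when the shard holds layer 0, and applies the final norm and output head only on the last shard. A small sketch of that arithmetic for the 0-30 / 31-31 split used by the tests elsewhere in this diff (the model id is just a label here):

from exo.inference.shard import Shard

shards = [Shard("llama3-8b", 0, 30, 32), Shard("llama3-8b", 31, 31, 32)]

for s in shards:
  local_blocks = s.end_layer - s.start_layer + 1  # TransformerBlocks this shard builds
  print(f"layers {s.start_layer}-{s.end_layer}: {local_blocks} blocks, "
        f"embeds={s.is_first_layer()}, head={s.is_last_layer()}")
# layers 0-30: 31 blocks, embeds=True, head=False
# layers 31-31: 1 block, embeds=False, head=True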

+ 23 - 16
exo/inference/mlx/sharded_inference_engine.py

@@ -6,24 +6,31 @@ from .sharded_utils import load_shard
 from ..shard import Shard
 from typing import Optional
 
+
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-    def __init__(self):
-        self.shard = None
+  def __init__(self):
+    self.shard = None
 
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+  async def infer_prompt(
+    self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
+  ) -> (np.ndarray, str, bool):
+    await self.ensure_shard(shard)
+    output_data: np.ndarray = np.array(
+      self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt)))
+    )
+    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+  async def infer_tensor(
+    self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
+  ) -> (np.ndarray, str, bool):
+    await self.ensure_shard(shard)
+    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
+    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def ensure_shard(self, shard: Shard):
-        if self.shard == shard:
-            return
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
 
-        model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
-        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-        self.shard = shard
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
+    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
+    self.shard = shard
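
A hedged usage sketch of the engine above: ensure_shard() only reloads weights when a different shard is requested, and infer_prompt() returns the output array, an inference-state string, and an end-of-sequence flag. The model id is borrowed from exo/inference/test_inference_engine.py, and downloading it is assumed to succeed:

import asyncio

from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.shard import Shard


async def main():
  engine = MLXDynamicShardInferenceEngine()
  shard = Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=31, n_layers=32)
  output, _state, is_eos = await engine.infer_prompt("req-1", shard=shard, prompt="Hello")
  # Repeated calls with the same shard reuse the already-loaded model.
  print(output.shape, is_eos)


asyncio.run(main())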

+ 51 - 50
exo/inference/mlx/sharded_model.py

@@ -7,62 +7,63 @@ from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
+
 class StatefulShardedModel:
-    def __init__(self, shard: Shard, model: nn.Module):
-        self.shard = shard
-        self.model = model
-        self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
+  def __init__(self, shard: Shard, model: nn.Module):
+    self.shard = shard
+    self.model = model
+    self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
 
-    def step(
-        self,
-        request_id: str,
-        x,
-        temp: float = 0.0,
-        top_p: float = 1.0,
-        logit_bias: Optional[Dict[int, float]] = None,
-    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-        def sample(logits: mx.array) -> Tuple[mx.array, float]:
-            if logit_bias:
-                indices = mx.array(list(logit_bias.keys()))
-                values = mx.array(list(logit_bias.values()))
-                logits[:, indices] += values
+  def step(
+    self,
+    request_id: str,
+    x,
+    temp: float = 0.0,
+    top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
+  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    def sample(logits: mx.array) -> Tuple[mx.array, float]:
+      if logit_bias:
+        indices = mx.array(list(logit_bias.keys()))
+        values = mx.array(list(logit_bias.values()))
+        logits[:, indices] += values
 
-            if temp == 0:
-                token = mx.argmax(logits, axis=-1)
-            else:
-                if top_p > 0 and top_p < 1.0:
-                    token = top_p_sampling(logits, top_p, temp)
-                else:
-                    token = mx.random.categorical(logits * (1 / temp))
+      if temp == 0:
+        token = mx.argmax(logits, axis=-1)
+      else:
+        if top_p > 0 and top_p < 1.0:
+          token = top_p_sampling(logits, top_p, temp)
+        else:
+          token = mx.random.categorical(logits * (1 / temp))
 
-            return token
+      return token
 
-        y = x
+    y = x
 
-        if request_id not in self.request_cache:
-            self.init_cache(request_id)
-        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
+    if request_id not in self.request_cache:
+      self.init_cache(request_id)
+    output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
 
-        if self.shard.is_last_layer():
-            logits = output[:, -1, :]
-            y = sample(logits)
-            return y
-        else:
-            return output
+    if self.shard.is_last_layer():
+      logits = output[:, -1, :]
+      y = sample(logits)
+      return y
+    else:
+      return output
 
-    def __call__(
-            self,
-            x,
-            temp: float = 0.0,
-            top_p: float = 1.0,
-        logit_bias: Optional[Dict[int, float]] = None,
-    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-        return self.step(x, temp, top_p, logit_bias)
+  def __call__(
+    self,
+    x,
+    temp: float = 0.0,
+    top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
+  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    return self.step(x, temp, top_p, logit_bias)
 
-    def init_cache(self, request_id: str):
-        kv_heads = (
-            [self.model.n_kv_heads] * len(self.model.layers)
-            if isinstance(self.model.n_kv_heads, int)
-            else self.model.n_kv_heads
-        )
-        self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
+  def init_cache(self, request_id: str):
+    kv_heads = (
+      [self.model.n_kv_heads] * len(self.model.layers)
+      if isinstance(self.model.n_kv_heads, int)
+      else self.model.n_kv_heads
+    )
+    self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
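
The point of request_cache above is isolation: each request_id lazily gets its own list of per-layer KVCache objects, so interleaved requests do not share attention state. The sketch below assumes a shard has already been loaded with load_shard (defined in sharded_utils.py later in this diff) and is illustrative only:

import asyncio

import mlx.core as mx

from exo.inference.mlx.sharded_model import StatefulShardedModel
from exo.inference.mlx.sharded_utils import load_shard
from exo.inference.shard import Shard


async def main():
  shard = Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=31, n_layers=32)
  model, tokenizer = await load_shard(shard.model_id, shard)
  sharded = StatefulShardedModel(shard, model)

  # Two independent requests: each gets its own per-layer KV cache on first use.
  sharded.step("request-A", mx.array(tokenizer.encode("Hello")))
  sharded.step("request-B", mx.array(tokenizer.encode("Bonjour")))
  print(len(sharded.request_cache))  # 2


asyncio.run(main())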

+ 200 - 187
exo/inference/mlx/sharded_utils.py

@@ -19,213 +19,226 @@ from mlx_lm.tuner.utils import apply_lora_layers
 
 from ..shard import Shard
 
+
 class ModelNotFoundError(Exception):
-    def __init__(self, message):
-        self.message = message
-        super().__init__(self.message)
+  def __init__(self, message):
+    self.message = message
+    super().__init__(self.message)
+
 
 MODEL_REMAPPING = {
-    "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
-    "sharded_phi-msft": "sharded_phixtral",
+  "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
+  "sharded_phi-msft": "sharded_phixtral",
 }
 
+
 def _get_classes(config: dict):
-    """
-    Retrieve the model and model args classes based on the configuration.
+  """
+  Retrieve the model and model args classes based on the configuration.
 
-    Args:
-        config (dict): The model configuration.
+  Args:
+  config (dict): The model configuration.
 
-    Returns:
-        A tuple containing the Model class and the ModelArgs class.
-    """
-    model_type = config["model_type"]
-    model_type = MODEL_REMAPPING.get(model_type, model_type)
-    try:
-        arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
-    except ImportError:
-        msg = f"Model type {model_type} not supported."
-        logging.error(msg)
-        raise ValueError(msg)
+  Returns:
+  A tuple containing the Model class and the ModelArgs class.
+  """
+  model_type = config["model_type"]
+  model_type = MODEL_REMAPPING.get(model_type, model_type)
+  try:
+    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
+  except ImportError:
+    msg = f"Model type {model_type} not supported."
+    logging.error(msg)
+    raise ValueError(msg)
+
+  return arch.Model, arch.ModelArgs
 
-    return arch.Model, arch.ModelArgs
 
 def load_config(model_path: Path) -> dict:
-    try:
-        with open(model_path / "config.json", "r") as f:
-            config = json.load(f)
-    except FileNotFoundError:
-        logging.error(f"Config file not found in {model_path}")
-        raise
-    return config
+  try:
+    with open(model_path / "config.json", "r") as f:
+      config = json.load(f)
+  except FileNotFoundError:
+    logging.error(f"Config file not found in {model_path}")
+    raise
+  return config
+
 
 def load_model_shard(
-    model_path: Path,
-    shard: Shard,
-    lazy: bool = False,
-    model_config: dict = {},
+  model_path: Path,
+  shard: Shard,
+  lazy: bool = False,
+  model_config: dict = {},
 ) -> nn.Module:
-    """
-    Load and initialize the model from a given path.
-
-    Args:
-        model_path (Path): The path to load the model from.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-        model_config(dict, optional): Configuration parameters for the model.
-            Defaults to an empty dictionary.
-
-    Returns:
-        nn.Module: The loaded and initialized model.
-
-    Raises:
-        FileNotFoundError: If the weight files (.safetensors) are not found.
-        ValueError: If the model class or args class are not found or cannot be instantiated.
-    """
-
-    config = load_config(model_path)
-    config.update(model_config)
-
-    # TODO hack
-    config["model_type"] = f"sharded_{config['model_type']}"
-    config["shard"] = {
-        "model_id": model_path.name,
-        "start_layer": shard.start_layer,
-        "end_layer": shard.end_layer,
-        "n_layers": shard.n_layers
-    }
-
-    weight_files = glob.glob(str(model_path / "model*.safetensors"))
-
-    if not weight_files:
-        # Try weight for back-compat
-        weight_files = glob.glob(str(model_path / "weight*.safetensors"))
-
-    if not weight_files:
-        logging.error(f"No safetensors found in {model_path}")
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    all_weights_keys = set()
-    for wf in weight_files:
-        weights_dict = mx.load(wf)
-        all_weights_keys.update(weights_dict.keys())
-        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})
-
-    model_class, model_args_class = _get_classes(config=config)
-
-    model_args = model_args_class.from_dict(config)
-    model = model_class(model_args)
-
-    if hasattr(model, "sanitize"):
-        weights = model.sanitize(weights)
-
-    if (quantization := config.get("quantization", None)) is not None:
-        nn.quantize(
-            model,
-            **quantization,
-            class_predicate=None,
-        )
-
-    filtered_weights = {}
-    for k, v in weights.items():
-        if k.startswith("model.layers."):
-            layer_num = int(k.split('.')[2])
-            if shard.start_layer <= layer_num <= shard.end_layer:
-                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
-                filtered_weights[new_key] = v
-        else:
-            filtered_weights[k] = v
-    weights = filtered_weights
+  """
+  Load and initialize the model from a given path.
+
+  Args:
+  model_path (Path): The path to load the model from.
+  lazy (bool): If False eval the model parameters to make sure they are
+  loaded in memory before returning, otherwise they will be loaded
+  when needed. Default: ``False``
+  model_config(dict, optional): Configuration parameters for the model.
+  Defaults to an empty dictionary.
+
+  Returns:
+  nn.Module: The loaded and initialized model.
+
+  Raises:
+  FileNotFoundError: If the weight files (.safetensors) are not found.
+  ValueError: If the model class or args class are not found or cannot be instantiated.
+  """
+
+  config = load_config(model_path)
+  config.update(model_config)
+
+  # TODO hack
+  config["model_type"] = f"sharded_{config['model_type']}"
+  config["shard"] = {
+    "model_id": model_path.name,
+    "start_layer": shard.start_layer,
+    "end_layer": shard.end_layer,
+    "n_layers": shard.n_layers,
+  }
+
+  weight_files = glob.glob(str(model_path / "model*.safetensors"))
+
+  if not weight_files:
+    # Try weight for back-compat
+    weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+
+  if not weight_files:
+    logging.error(f"No safetensors found in {model_path}")
+    raise FileNotFoundError(f"No safetensors found in {model_path}")
+
+  weights = {}
+  all_weights_keys = set()
+  for wf in weight_files:
+    weights_dict = mx.load(wf)
+    all_weights_keys.update(weights_dict.keys())
+    weights.update(
+      {
+        k: v
+        for k, v in weights_dict.items()
+        if not k.startswith("model.layers.") or shard.start_layer <= int(k.split(".")[2]) <= shard.end_layer
+      }
+    )
+
+  model_class, model_args_class = _get_classes(config=config)
+
+  model_args = model_args_class.from_dict(config)
+  model = model_class(model_args)
+
+  if hasattr(model, "sanitize"):
+    weights = model.sanitize(weights)
+
+  if (quantization := config.get("quantization", None)) is not None:
+    nn.quantize(
+      model,
+      **quantization,
+      class_predicate=None,
+    )
+
+  filtered_weights = {}
+  for k, v in weights.items():
+    if k.startswith("model.layers."):
+      layer_num = int(k.split(".")[2])
+      if shard.start_layer <= layer_num <= shard.end_layer:
+        new_key = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
+        filtered_weights[new_key] = v
+    else:
+      filtered_weights[k] = v
+  weights = filtered_weights
+
+  model.load_weights(list(weights.items()), strict=False)
+
+  if not lazy:
+    mx.eval(model.parameters())
+
+  model.eval()
+  return model
 
-    model.load_weights(list(weights.items()), strict=False)
-
-    if not lazy:
-        mx.eval(model.parameters())
-
-    model.eval()
-    return model
 
 async def snapshot_download_async(*args, **kwargs):
-    func = partial(snapshot_download, *args, **kwargs)
-    return await asyncio.get_event_loop().run_in_executor(None, func)
+  func = partial(snapshot_download, *args, **kwargs)
+  return await asyncio.get_event_loop().run_in_executor(None, func)
+
 
 async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
-    """
-    Ensures the model is available locally. If the path does not exist locally,
-    it is downloaded from the Hugging Face Hub.
-
-    Args:
-        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-    Returns:
-        Path: The path to the model.
-    """
-    model_path = Path(path_or_hf_repo)
-    if not model_path.exists():
-        try:
-            model_path = Path(
-                await snapshot_download_async(
-                    repo_id=path_or_hf_repo,
-                    revision=revision,
-                    allow_patterns=[
-                        "*.json",
-                        "*.safetensors",
-                        "*.py",
-                        "tokenizer.model",
-                        "*.tiktoken",
-                        "*.txt",
-                    ],
-                )
-            )
-        except RepositoryNotFoundError:
-            raise ModelNotFoundError(
-                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
-                "Please make sure you specified the local path or Hugging Face"
-                " repo id correctly.\nIf you are trying to access a private or"
-                " gated Hugging Face repo, make sure you are authenticated:\n"
-                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
-            ) from None
-    return model_path
+  """
+  Ensures the model is available locally. If the path does not exist locally,
+  it is downloaded from the Hugging Face Hub.
+
+  Args:
+  path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+  revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+
+  Returns:
+  Path: The path to the model.
+  """
+  model_path = Path(path_or_hf_repo)
+  if not model_path.exists():
+    try:
+      model_path = Path(
+        await snapshot_download_async(
+          repo_id=path_or_hf_repo,
+          revision=revision,
+          allow_patterns=[
+            "*.json",
+            "*.safetensors",
+            "*.py",
+            "tokenizer.model",
+            "*.tiktoken",
+            "*.txt",
+          ],
+        )
+      )
+    except RepositoryNotFoundError:
+      raise ModelNotFoundError(
+        f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
+        "Please make sure you specified the local path or Hugging Face"
+        " repo id correctly.\nIf you are trying to access a private or"
+        " gated Hugging Face repo, make sure you are authenticated:\n"
+        "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
+      ) from None
+  return model_path
 
 
 async def load_shard(
-    path_or_hf_repo: str,
-    shard: Shard,
-    tokenizer_config={},
-    model_config={},
-    adapter_path: Optional[str] = None,
-    lazy: bool = False,
+  path_or_hf_repo: str,
+  shard: Shard,
+  tokenizer_config={},
+  model_config={},
+  adapter_path: Optional[str] = None,
+  lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
-    """
-    Load the model and tokenizer from a given path or a huggingface repository.
-
-    Args:
-        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-            Defaults to an empty dictionary.
-        model_config(dict, optional): Configuration parameters specifically for the model.
-            Defaults to an empty dictionary.
-        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-            to the model. Default: ``None``.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-    Returns:
-        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-    Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
-    """
-    model_path = await get_model_path(path_or_hf_repo)
-
-    model = load_model_shard(model_path, shard, lazy, model_config)
-    if adapter_path is not None:
-        model = apply_lora_layers(model, adapter_path)
-        model.eval()
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
-
-    return model, tokenizer
+  """
+  Load the model and tokenizer from a given path or a huggingface repository.
+
+  Args:
+  path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
+  tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+  Defaults to an empty dictionary.
+  model_config(dict, optional): Configuration parameters specifically for the model.
+  Defaults to an empty dictionary.
+  adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+  to the model. Default: ``None``.
+  lazy (bool): If False eval the model parameters to make sure they are
+  loaded in memory before returning, otherwise they will be loaded
+  when needed. Default: ``False``
+  Returns:
+  Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
+
+  Raises:
+  FileNotFoundError: If config file or safetensors are not found.
+  ValueError: If model class or args class are not found.
+  """
+  model_path = await get_model_path(path_or_hf_repo)
+
+  model = load_model_shard(model_path, shard, lazy, model_config)
+  if adapter_path is not None:
+    model = apply_lora_layers(model, adapter_path)
+    model.eval()
+  tokenizer = load_tokenizer(model_path, tokenizer_config)
+
+  return model, tokenizer
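
The most subtle part of load_model_shard above is the key rewrite: weights outside [start_layer, end_layer] are dropped, and kept layer weights are renumbered so indices are local to the shard. A standalone rerun of that arithmetic on a made-up key name:

# Illustrative only: the same remapping load_model_shard applies, for a shard holding layer 31 alone.
start_layer, end_layer = 31, 31
key = "model.layers.31.self_attn.q_proj.weight"

layer_num = int(key.split(".")[2])
if start_layer <= layer_num <= end_layer:
  new_key = f"model.layers.{layer_num - start_layer}." + ".".join(key.split(".")[3:])
  print(new_key)  # model.layers.0.self_attn.q_proj.weight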

+ 5 - 5
exo/inference/mlx/test_sharded_llama.py

@@ -23,17 +23,17 @@ max_tokens = 50
 resp = prompt_tokens
 full_generated_tokens = []
 for _ in range(max_tokens):
-    resp = full.step(resp)
-    full_generated_tokens.append(resp.item())
+  resp = full.step(resp)
+  full_generated_tokens.append(resp.item())
 
 print("full response: ", full_tokenizer.decode(full_generated_tokens))
 
 sharded_generated_tokens = []
 sharded_resp = prompt_tokens
 for _ in range(max_tokens):
-    resp1 = m1.step(sharded_resp)
-    sharded_resp = m2.step(resp1)
-    sharded_generated_tokens.append(sharded_resp.item())
+  resp1 = m1.step(sharded_resp)
+  sharded_resp = m2.step(resp1)
+  sharded_generated_tokens.append(sharded_resp.item())
 
 print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
 

+ 29 - 27
exo/inference/mlx/test_sharded_model.py

@@ -5,32 +5,34 @@ import mlx.nn as nn
 from typing import Optional
 import numpy as np
 
+
 class DummyModel(nn.Module):
-    def __init__(self, shard: Optional[Shard] = None):
-        self.shard = shard
-        self.layers = [
-            nn.Linear(8, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 8),
-        ]
-
-        self.n_kv_heads = 4
-        self.head_dim = 4
-
-    def __call__(self, x, cache=None):
-        if self.shard:
-            for layer in self.layers[self.shard.start_layer:self.shard.end_layer+1]:
-                x = layer(x)
-            if self.shard.is_last_layer():
-                x =  x.reshape((1, 2, 4))
-        else:
-            for layer in self.layers:
-                x = layer(x)
-            x = x.reshape((1, 2, 4))
-
-        return x
+  def __init__(self, shard: Optional[Shard] = None):
+    self.shard = shard
+    self.layers = [
+      nn.Linear(8, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 8),
+    ]
+
+    self.n_kv_heads = 4
+    self.head_dim = 4
+
+  def __call__(self, x, cache=None):
+    if self.shard:
+      for layer in self.layers[self.shard.start_layer : self.shard.end_layer + 1]:
+        x = layer(x)
+      if self.shard.is_last_layer():
+        x = x.reshape((1, 2, 4))
+    else:
+      for layer in self.layers:
+        x = layer(x)
+      x = x.reshape((1, 2, 4))
+
+    return x
+
 
 model = DummyModel()
 model.save_weights("./test_weights.npz")
@@ -44,8 +46,8 @@ model.load_weights("./test_weights.npz")
 sharded_model1.load_weights("./test_weights.npz")
 sharded_model2.load_weights("./test_weights.npz")
 
-fullresp = model(mx.array([1,2,3,4,5,6,7,8]))
-resp1 = sharded_model1(mx.array([1,2,3,4,5,6,7,8]))
+fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
+resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
 resp2 = sharded_model2(resp1)
 
 assert np.all(np.array(fullresp) == np.array(resp2))
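
The DummyModel above selects its layers with self.layers[shard.start_layer : shard.end_layer + 1], so the final assert only holds when the two shards are contiguous and together cover all five Linear layers. The exact boundaries used for sharded_model1/2 fall outside the hunks shown, so the split below is an assumption chosen only to make the slice arithmetic concrete:

layers = ["lin0", "lin1", "lin2", "lin3", "lin4"]  # stand-ins for the five nn.Linear layers
shard1 = (0, 2)  # assumed (start_layer, end_layer) for sharded_model1
shard2 = (3, 4)  # assumed (start_layer, end_layer) for sharded_model2

print(layers[shard1[0]:shard1[1] + 1])  # ['lin0', 'lin1', 'lin2']
print(layers[shard2[0]:shard2[1] + 1])  # ['lin3', 'lin4']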

+ 16 - 15
exo/inference/shard.py

@@ -1,22 +1,23 @@
 from dataclasses import dataclass
 
+
 @dataclass
 class Shard:
-    model_id: str
-    start_layer: int
-    end_layer: int
-    n_layers: int
+  model_id: str
+  start_layer: int
+  end_layer: int
+  n_layers: int
 
-    def is_first_layer(self) -> bool:
-        return self.start_layer == 0
+  def is_first_layer(self) -> bool:
+    return self.start_layer == 0
 
-    def is_last_layer(self) -> bool:
-        return self.end_layer == self.n_layers - 1
+  def is_last_layer(self) -> bool:
+    return self.end_layer == self.n_layers - 1
 
-    def to_dict(self) -> dict:
-        return {
-            "model_id": self.model_id,
-            "start_layer": self.start_layer,
-            "end_layer": self.end_layer,
-            "n_layers": self.n_layers
-        }
+  def to_dict(self) -> dict:
+    return {
+      "model_id": self.model_id,
+      "start_layer": self.start_layer,
+      "end_layer": self.end_layer,
+      "n_layers": self.n_layers,
+    }
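
to_dict() is the wire-friendly form of a shard; ModelArgs.__post_init__ in sharded_llama.py above rebuilds a Shard from exactly such a dict via Shard(**d). Because Shard is a dataclass, the round trip is directly checkable:

from exo.inference.shard import Shard

s = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=30, n_layers=32)
d = s.to_dict()

assert Shard(**d) == s  # dataclass equality makes the round trip explicit
assert s.is_first_layer() and not s.is_last_layer()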

+ 42 - 12
exo/inference/test_inference_engine.py

@@ -5,25 +5,55 @@ from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-    prompt = "In a single word only, what is the last name of the current president of the USA?"
-    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
+async def test_inference_engine(
+  inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str
+):
+  prompt = "In a single word only, what is the last name of the current president of the USA?"
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
+    "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt
+  )
+  next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+    "A",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    input_data=resp_full,
+    inference_state=inference_state_full,
+  )
+
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
+    "B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt
+  )
+  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp1,
+    inference_state=inference_state_1,
+  )
+  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
+    input_data=resp2,
+    inference_state=inference_state_2,
+  )
+  resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp3,
+    inference_state=inference_state_3,
+  )
 
-    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
-    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
-    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
+  assert np.array_equal(resp_full, resp2)
+  assert np.array_equal(next_resp_full, resp4)
 
-    assert np.array_equal(resp_full, resp2)
-    assert np.array_equal(next_resp_full, resp4)
 
-asyncio.run(test_inference_engine(
+asyncio.run(
+  test_inference_engine(
     MLXDynamicShardInferenceEngine(),
     MLXDynamicShardInferenceEngine(),
     "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
-))
+  )
+)
 
 # TODO: Need more memory or a smaller model
 # asyncio.run(test_inference_engine(
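
The comment at the top of this test states the invariant all of the reformatting leaves intact: any contiguous split of the layers should reproduce the single-shard output. The split is hard-coded as 0-30 / 31-31 inside the test body, so the alternative boundary below is only a hypothetical illustration of what "contiguous" means, not a case this file exercises:

from exo.inference.shard import Shard

a = Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=15, n_layers=32)
b = Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=16, end_layer=31, n_layers=32)

# Contiguous: b starts exactly where a stops, and together they span all 32 layers.
assert b.start_layer == a.end_layer + 1
assert a.is_first_layer() and b.is_last_layer()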

+ 190 - 91
exo/inference/tinygrad/inference.py

@@ -16,17 +16,37 @@ import os
 
 MODEL_PARAMS = {
   "8B": {
-    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
-    "files": 1
+    "args": {
+      "dim": 4096,
+      "n_heads": 32,
+      "n_kv_heads": 8,
+      "n_layers": 32,
+      "norm_eps": 1e-5,
+      "rope_theta": 500000,
+      "vocab_size": 128256,
+      "hidden_dim": 14336,
+    },
+    "files": 1,
   },
   "70B": {
-    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
-    "files": 8
-  }
+    "args": {
+      "dim": 8192,
+      "n_heads": 64,
+      "n_kv_heads": 8,
+      "n_layers": 80,
+      "norm_eps": 1e-5,
+      "rope_theta": 500000,
+      "vocab_size": 128256,
+      "hidden_dim": 28672,
+    },
+    "files": 8,
+  },
 }
 
+
 class Tokenizer:
   pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+
   def __init__(self, model_path: str):
     mergeable_ranks = load_tiktoken_bpe(model_path)
     self.num_base_tokens = len(mergeable_ranks)
@@ -41,29 +61,38 @@ class Tokenizer:
       "<|end_header_id|>",
       "<|reserved_special_token_4|>",
       "<|eot_id|>",
-    ] + [
-      f"<|reserved_special_token_{i}|>"
-      for i in range(5, 256 - 5)
-    ]
+    ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
     self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
 
-    self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)
+    self.model = tiktoken.Encoding(
+      name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens
+    )
 
   @property
-  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
+  def bos_id(self):
+    return self.special_tokens["<|begin_of_text|>"]
+
   @property
-  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
+  def stop_tokens(self):
+    return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
 
   def decode(self, toks):
-     return self.model.decode([t for t in toks if t < self.num_base_tokens])
+    return self.model.decode([t for t in toks if t < self.num_base_tokens])
+
   def encode(self, text, allow_special=False):
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
 
+
 # **** helper functions ****
-async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None,
-                      allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path:
-    func = partial(fetch, url, name, subdir, allow_caching)
-    return await asyncio.get_event_loop().run_in_executor(None, func)
+async def fetch_async(
+  url: str,
+  name: Optional[Union[Path, str]] = None,
+  subdir: Optional[str] = None,
+  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
+) -> Path:
+  func = partial(fetch, url, name, subdir, allow_caching)
+  return await asyncio.get_event_loop().run_in_executor(None, func)
+
 
 def concat_weights(models, device=None):
   def convert(name) -> Tensor:
@@ -73,11 +102,14 @@ def concat_weights(models, device=None):
     axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     lazy_tensors = [data.to(device=device) for data in disk_tensors]
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+
   return {name: convert(name) for name in {name: None for model in models for name in model}}
 
-def load(fn:str):
-  if fn.endswith('.index.json'):
-    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+
+def load(fn: str):
+  if fn.endswith(".index.json"):
+    with open(fn) as fp:
+      weight_map = json.load(fp)["weight_map"]
     parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
     return {k: parts[n][k] for k, n in weight_map.items()}
   elif fn.endswith(".safetensors"):
@@ -85,6 +117,7 @@ def load(fn:str):
   else:
     return torch_load(fn)
 
+
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
   # build model
   linear = nn.Linear
@@ -93,44 +126,67 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=
 
   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
-    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
-    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path / "model.safetensors.index.json").exists():
+      weights = load(str(model_path / "model.safetensors.index.json"))
+    elif (model_path / "model.safetensors").exists():
+      weights = load(str(model_path / "model.safetensors"))
+    else:
+      weights = concat_weights(
+        [load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])],
+        device[0] if isinstance(device, tuple) else device,
+      )
   else:
     weights = load(str(model_path))
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"], shard=shard)
+    weights = convert_from_huggingface(
+      weights,
+      model,
+      MODEL_PARAMS[model_size]["args"]["n_heads"],
+      MODEL_PARAMS[model_size]["args"]["n_kv_heads"],
+      shard=shard,
+    )
   weights = fix_bf16(weights)
 
   with Context(BEAM=0):
     # quantize
     if quantize is not None:
       weights = linear.quantize(weights, device)
-      for _,v in weights.items(): v.realize()
+      for _, v in weights.items():
+        v.realize()
 
     # shard
     if isinstance(device, tuple):
-      for k,v in nn.state.get_state_dict(model).items():
-        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
-        elif '.attention.' in k: v.shard_(device, axis=-1)
-        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
-        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
-        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
-        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
-        elif 'output.weight' in k: v.shard_(device, axis=0)
-        else: v.shard_(device, axis=None)
+      for k, v in nn.state.get_state_dict(model).items():
+        if "scale" in k:
+          v.shard_(device, axis=None)  # from quantized
+        elif ".attention." in k:
+          v.shard_(device, axis=-1)
+        elif ".feed_forward.w1." in k:
+          v.shard_(device, axis=0)
+        elif ".feed_forward.w3." in k:
+          v.shard_(device, axis=0)
+        elif ".feed_forward." in k:
+          v.shard_(device, axis=-1)
+        elif "tok_embeddings.weight" in k:
+          v.shard_(device, axis=0)
+        elif "output.weight" in k:
+          v.shard_(device, axis=0)
+        else:
+          v.shard_(device, axis=None)
 
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=True)
   return model
 
+
 # default settings
-TEMPERATURE = 0 # 0.85
+TEMPERATURE = 0  # 0.85
 TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 
+
 def prefill(model, toks, start_pos=0):
   # prefill the model
   for tok in tqdm(toks):
@@ -139,71 +195,114 @@ def prefill(model, toks, start_pos=0):
     start_pos += 1
   return start_pos
 
+
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-    def __init__(self):
-        self.shard = None
+  def __init__(self):
+    self.shard = None
+
+  async def infer_prompt(
+    self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None
+  ) -> (np.ndarray, str, bool):
+    # TODO: we need to refactor models/llama to handle a per-request kv cache; right now it's shared between requests.
+    await self.ensure_shard(shard)
+    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
-        await self.ensure_shard(shard)
-        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    toks = self.tokenizer.encode(prompt)
+    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
+    last_tok = toks[-1]
 
-        toks = self.tokenizer.encode(prompt)
-        start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
-        last_tok = toks[-1]
+    output_data = np.array(
+      [self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()]
+    )
+    if output_data.size == 1:
+      start_pos += 1
 
-        output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-        if output_data.size == 1:
-           start_pos += 1
+    return (
+      output_data,
+      json.dumps({"start_pos": start_pos}),
+      output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
+    )
 
-        return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+  async def infer_tensor(
+    self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None
+  ) -> (np.ndarray, str, bool):
+    await self.ensure_shard(shard)
+    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        await self.ensure_shard(shard)
-        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    output_data: np.ndarray = np.array(
+      [self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()]
+    )
+    if output_data.size == 1:
+      start_pos += 1
 
-        output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-        if output_data.size == 1:
-           start_pos += 1
+    return (
+      output_data,
+      json.dumps({"start_pos": start_pos}),
+      output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
+    )
 
-        return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
 
-    async def ensure_shard(self, shard: Shard):
-        if self.shard == shard:
-            return
+    model_path = Path(shard.model_id)
+    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
+    model_path = models_dir / shard.model_id
+    size = "8B"
+    if Path(model_path / "model.safetensors.index.json").exists():
+      model = model_path
+    else:
 
-        model_path = Path(shard.model_id)
-        models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
-        model_path = models_dir / shard.model_id
+      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
+      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
+        await fetch_async(
+          "https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model",
+          "tokenizer.model",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors",
+          "model-00001-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors",
+          "model-00002-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors",
+          "model-00003-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors",
+          "model-00004-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        model = await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json",
+          "model.safetensors.index.json",
+          subdir=shard.model_id,
+        )
         size = "8B"
-        if Path(model_path / "model.safetensors.index.json").exists():
-            model = model_path
-        else:
+      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
+        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
+        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
+        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
+        # size = "70B"
+      else:
+        raise ValueError(
+          f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}"
+        )
+
+    model = build_transformer(model_path, shard=shard, model_size=size)
+    tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
 
-            if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
-            if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-                await fetch_async("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-                model = await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-                size = "8B"
-            elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
-                raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
-                # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-                # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-                # size = "70B"
-            else:
-                raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
-
-        model = build_transformer(model_path, shard=shard, model_size=size)
-        tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
-
-        self.shard = shard
-        self.model = model
-        self.tokenizer = tokenizer
+    self.shard = shard
+    self.model = model
+    self.tokenizer = tokenizer
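
A quick orientation for the reformatted engine above: infer_prompt and infer_tensor thread the decoding position through the opaque inference_state JSON ({"start_pos": ...}) and report completion via the third element of the returned tuple. A minimal single-node driver loop might look like the sketch below; the generate helper, the "req-1" id, and the single-shard assumption are hypothetical and not part of this commit.

async def generate(engine, shard, prompt, max_tokens=64):
  # Hypothetical driver sketch; assumes one shard spanning all layers.
  output, state, done = await engine.infer_prompt("req-1", shard, prompt)
  tokens = [int(output.item())] if output.size == 1 else []
  while not done and len(tokens) < max_tokens:
    # Feed the sampled token back in; state carries start_pos between calls.
    output, state, done = await engine.infer_tensor("req-1", shard, output, inference_state=state)
    if output.size == 1:
      tokens.append(int(output.item()))
  return tokens

# e.g. asyncio.run(generate(TinygradDynamicShardInferenceEngine(), shard, "hello"))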

+ 146 - 43
exo/inference/tinygrad/models/llama.py

@@ -3,22 +3,27 @@ from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from exo.inference.shard import Shard
 
+
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
+
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a,b = A[..., 0:1], A[..., 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
+  a, b = A[..., 0:1], A[..., 1:2]
+  ro = a * c - b * d
+  co = a * d + b * c
   return ro.cat(co, dim=-1)
 
-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
-  assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
+  assert (
+    freqs_cis.shape[1] == xq.shape[1] == xk.shape[1]
+  ), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
   assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
@@ -27,16 +32,21 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
-def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1: return x
+  if n_rep == 1:
+    return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
+
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = (
+      n_kv_heads if n_kv_heads is not None else n_heads
+    )  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -46,7 +56,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@@ -57,17 +67,23 @@ class Attention:
 
     # create kv cache
     if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
+      self.cache_kv = (
+        Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype)
+        .contiguous()
+        .realize()
+      )
       if isinstance(x.device, tuple):
         # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
         self.cache_kv.shard_((x.device), axis=None).realize()
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(
+      Tensor.stack(xk, xv)
+    ).realize()
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -75,26 +91,39 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
 
+
 class FeedForward:
-  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
+  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
+    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
+
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_kv_heads: int,
+    norm_eps: float,
+    max_context: int,
+    linear=nn.Linear,
+    feed_forward=FeedForward,
+  ):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
+
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -102,7 +131,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
   # if temperature is very low just use argmax
-  if temp < 1e-6: return logits.argmax()
+  if temp < 1e-6:
+    return logits.argmax()
 
   # alpha sampling
   if af or ap:
@@ -116,10 +146,16 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # softmax
   t = (logits / temp).softmax()
 
-  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
+  counter, counter2 = (
+    Tensor.arange(t.numel(), device=logits.device).contiguous(),
+    Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
+  )
   # top k
   if k:
-    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
+    output, output_indices = (
+      Tensor.zeros(k, device=logits.device).contiguous(),
+      Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous(),
+    )
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
@@ -144,9 +180,30 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
   return output_token
 
+
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
-    self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_layers: int,
+    norm_eps: float,
+    vocab_size,
+    shard: Shard,
+    linear=nn.Linear,
+    n_kv_heads=None,
+    rope_theta=10000,
+    max_context=1024,
+    jit=True,
+    feed_forward=FeedForward,
+  ):
+    self.layers = [
+      TransformerBlock(
+        dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward
+      )
+      for _ in range(shard.end_layer - shard.start_layer + 1)
+    ]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -155,13 +212,28 @@ class Transformer:
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
-  def forward(self, h:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+  def forward(
+    self,
+    h: Tensor,
+    start_pos: Union[Variable, int],
+    temperature: float,
+    top_k: int,
+    top_p: float,
+    alpha_f: float,
+    alpha_p: float,
+  ):
     seqlen = h.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
 
     if self.shard.is_first_layer():
       h = self.tok_embeddings(h)
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    mask = (
+      Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device)
+      .triu(start_pos + 1)
+      .realize()
+      if seqlen > 1
+      else None
+    )
 
     for i, layer in enumerate(self.layers):
       h = layer(h, start_pos, freqs_cis, mask)
@@ -169,12 +241,21 @@ class Transformer:
       #   print(f"layer {i}: {str(h.numpy())[:60]}")
 
     if self.shard.is_last_layer():
-        logits = self.output(self.norm(h)).float()[:, -1, :]
-        return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float()[:, -1, :]
+      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
     else:
       return h.realize()
 
-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+  def __call__(
+    self,
+    tokens: Tensor,
+    start_pos: Variable,
+    temperature: float = 0.0,
+    top_k: int = 0,
+    top_p: float = 0.8,
+    alpha_f: float = 0.0,
+    alpha_p: float = 0.0,
+  ):
     # TODO: better way to handle the first call v.s. the rest?
     # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
     #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
@@ -185,29 +266,48 @@ class Transformer:
       if hasattr(layer.attention, "cache_kv"):
         layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
 
+
 # *** helpers ***
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
+
+def convert_from_huggingface(
+  weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard
+):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
   keymap = {
     "model.embed_tokens.weight": "tok_embeddings.weight",
-    **{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
-    **{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
-    **{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
+    **{
+      f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
+      for l in range(len(model.layers))
+    },
+    **{
+      f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
+      for x in ["q", "k", "v", "o"]
+      for l in range(len(model.layers))
+    },
+    **{
+      f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
+      for l in range(len(model.layers))
+    },
+    **{
+      f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
+      for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
+      for l in range(len(model.layers))
+    },
     "model.norm.weight": "norm.weight",
     "lm_head.weight": "output.weight",
   }
   sd = {}
   for k, v in weights.items():
-    if ".rotary_emb." in k: continue
+    if ".rotary_emb." in k:
+      continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
-      layer_num = int(k.split('.')[2])
+      layer_num = int(k.split(".")[2])
       if shard.start_layer <= layer_num <= shard.end_layer:
-          k = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
+        k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
       else:
         continue
 
@@ -218,9 +318,12 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     sd[keymap[k]] = v
   return sd
 
-def fix_bf16(weights:Dict[Any, Tensor]):
+
+def fix_bf16(weights: Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
-  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  return {
+    k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()
+  }

+ 1 - 1
exo/networking/__init__.py

@@ -2,4 +2,4 @@ from .discovery import Discovery
 from .peer_handle import PeerHandle
 from .server import Server
 
-__all__ = ['Discovery', 'PeerHandle', 'Server']
+__all__ = ["Discovery", "PeerHandle", "Server"]

+ 10 - 9
exo/networking/discovery.py

@@ -2,15 +2,16 @@ from abc import ABC, abstractmethod
 from typing import List
 from .peer_handle import PeerHandle
 
+
 class Discovery(ABC):
-    @abstractmethod
-    async def start(self) -> None:
-        pass
+  @abstractmethod
+  async def start(self) -> None:
+    pass
 
-    @abstractmethod
-    async def stop(self) -> None:
-        pass
+  @abstractmethod
+  async def stop(self) -> None:
+    pass
 
-    @abstractmethod
-    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-        pass
+  @abstractmethod
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    pass
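
For readers implementing the Discovery interface above: the contract is just async start/stop plus discover_peers returning PeerHandle objects. A minimal sketch (hypothetical StaticDiscovery class, not part of this commit) that satisfies it with a fixed, preconfigured peer list:

from typing import List
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle

class StaticDiscovery(Discovery):
  def __init__(self, peers: List[PeerHandle]):
    self.peers = peers

  async def start(self) -> None:
    pass  # peers are known ahead of time; nothing to set up

  async def stop(self) -> None:
    pass

  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
    return self.peers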

+ 182 - 137
exo/networking/grpc/grpc_discovery.py

@@ -9,147 +9,192 @@ from .grpc_peer_handle import GRPCPeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
+
 class ListenProtocol(asyncio.DatagramProtocol):
-    def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
-        super().__init__()
-        self.on_message = on_message
-        self.loop = asyncio.get_event_loop()
+  def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
+    super().__init__()
+    self.on_message = on_message
+    self.loop = asyncio.get_event_loop()
 
-    def connection_made(self, transport):
-        self.transport = transport
+  def connection_made(self, transport):
+    self.transport = transport
 
-    def datagram_received(self, data, addr):
-        asyncio.create_task(self.on_message(data, addr))
+  def datagram_received(self, data, addr):
+    asyncio.create_task(self.on_message(data, addr))
 
 
 class GRPCDiscovery(Discovery):
-    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, discovery_timeout: int = 30):
-        self.node_id = node_id
-        self.node_port = node_port
-        self.device_capabilities = device_capabilities
-        self.listen_port = listen_port
-        self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
-        self.broadcast_interval = broadcast_interval
-        self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
-        self.broadcast_task = None
-        self.listen_task = None
-        self.cleanup_task = None
-        self.discovery_timeout = discovery_timeout
-
-    async def start(self):
-        self.device_capabilities = device_capabilities()
-        self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
-        self.listen_task = asyncio.create_task(self.task_listen_for_peers())
-        self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
-
-    async def stop(self):
-        if self.broadcast_task:
-            self.broadcast_task.cancel()
-        if self.listen_task:
-            self.listen_task.cancel()
-        if self.cleanup_task:
-            self.cleanup_task.cancel()
-        if self.broadcast_task or self.listen_task or self.cleanup_task:
-            await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
-
-    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-        if DEBUG_DISCOVERY >= 2: print("Starting peer discovery process...")
-
+  def __init__(
+    self,
+    node_id: str,
+    node_port: int,
+    listen_port: int,
+    broadcast_port: int = None,
+    broadcast_interval: int = 1,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    discovery_timeout: int = 30,
+  ):
+    self.node_id = node_id
+    self.node_port = node_port
+    self.device_capabilities = device_capabilities
+    self.listen_port = listen_port
+    self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
+    self.broadcast_interval = broadcast_interval
+    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
+    self.broadcast_task = None
+    self.listen_task = None
+    self.cleanup_task = None
+    self.discovery_timeout = discovery_timeout
+
+  async def start(self):
+    self.device_capabilities = device_capabilities()
+    self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
+    self.listen_task = asyncio.create_task(self.task_listen_for_peers())
+    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+
+  async def stop(self):
+    if self.broadcast_task:
+      self.broadcast_task.cancel()
+    if self.listen_task:
+      self.listen_task.cancel()
+    if self.cleanup_task:
+      self.cleanup_task.cancel()
+    if self.broadcast_task or self.listen_task or self.cleanup_task:
+      await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if DEBUG_DISCOVERY >= 2:
+      print("Starting peer discovery process...")
+
+    if wait_for_peers > 0:
+      while len(self.known_peers) == 0:
+        if DEBUG_DISCOVERY >= 2:
+          print("No peers discovered yet, retrying in 1 second...")
+        await asyncio.sleep(1)  # Keep trying to find peers
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
+
+    grace_period = 5  # seconds
+    while True:
+      initial_peer_count = len(self.known_peers)
+      if DEBUG_DISCOVERY >= 2:
+        print(
+          f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more..."
+        )
+      if len(self.known_peers) == initial_peer_count:
         if wait_for_peers > 0:
-            while len(self.known_peers) == 0:
-                if DEBUG_DISCOVERY >= 2: print("No peers discovered yet, retrying in 1 second...")
-                await asyncio.sleep(1)  # Keep trying to find peers
-            if DEBUG_DISCOVERY >= 2: print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
-
-        grace_period = 5  # seconds
-        while True:
-            initial_peer_count = len(self.known_peers)
-            if DEBUG_DISCOVERY >= 2: print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
-            if len(self.known_peers) == initial_peer_count:
-                if wait_for_peers > 0:
-                    await asyncio.sleep(grace_period)
-                    if DEBUG_DISCOVERY >= 2: print(f"Waiting additional {wait_for_peers} seconds for more peers.")
-                    wait_for_peers = 0
-                else:
-                    if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
-                    break  # No new peers found in the grace period, we are done
-
-        return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
-
-    async def task_broadcast_presence(self):
-        transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-                    lambda: asyncio.DatagramProtocol(),
-                    local_addr=('0.0.0.0', 0),
-                    family=socket.AF_INET)
-        sock = transport.get_extra_info('socket')
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-
-        message = json.dumps({
-            "type": "discovery",
-            "node_id": self.node_id,
-            "grpc_port": self.node_port,
-            "device_capabilities": self.device_capabilities.to_dict()
-        }).encode('utf-8')
-
-        while True:
-            try:
-                if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
-                transport.sendto(message, ('<broadcast>', self.broadcast_port))
-                await asyncio.sleep(self.broadcast_interval)
-            except Exception as e:
-                print(f"Error in broadcast presence: {e}")
-                import traceback
-                print(traceback.format_exc())
-
-    async def on_listen_message(self, data, addr):
-        if not data:
-            return
-
-        decoded_data = data.decode('utf-8', errors='ignore')
-        
-        # Check if the decoded data starts with a valid JSON character
-        if not (decoded_data.strip() and decoded_data.strip()[0] in '{['):
-            if DEBUG_DISCOVERY >= 2: print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
-            return
-
-        try:
-            decoder = json.JSONDecoder(strict=False)
-            message = decoder.decode(decoded_data)
-        except json.JSONDecodeError as e:
-            if DEBUG_DISCOVERY >= 2: print(f"Error decoding JSON data from {addr}: {e}")
-            return
-
-        if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
-
-        if message['type'] == 'discovery' and message['node_id'] != self.node_id:
-            peer_id = message['node_id']
-            peer_host = addr[0]
-            peer_port = message['grpc_port']
-            device_capabilities = DeviceCapabilities(**message['device_capabilities'])
-            if peer_id not in self.known_peers:
-                self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
-                if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-            self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
-
-    async def task_listen_for_peers(self):
-        await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
-        if DEBUG_DISCOVERY >= 2: print("Started listen task")
-
-    async def task_cleanup_peers(self):
-        while True:
-            try:
-                current_time = time.time()
-                peers_to_remove = [
-                    peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
-                    (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
-                ]
-                if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
-                if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
-                for peer_id in peers_to_remove:
-                    if peer_id in self.known_peers: del self.known_peers[peer_id]
-                    if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
-                await asyncio.sleep(self.broadcast_interval)
-            except Exception as e:
-                print(f"Error in cleanup peers: {e}")
-                import traceback
-                print(traceback.format_exc())
+          await asyncio.sleep(grace_period)
+          if DEBUG_DISCOVERY >= 2:
+            print(f"Waiting additional {wait_for_peers} seconds for more peers.")
+          wait_for_peers = 0
+        else:
+          if DEBUG_DISCOVERY >= 2:
+            print("No new peers discovered in the last grace period. Ending discovery process.")
+          break  # No new peers found in the grace period, we are done
+
+    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+
+  async def task_broadcast_presence(self):
+    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
+      lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET
+    )
+    sock = transport.get_extra_info("socket")
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+
+    message = json.dumps(
+      {
+        "type": "discovery",
+        "node_id": self.node_id,
+        "grpc_port": self.node_port,
+        "device_capabilities": self.device_capabilities.to_dict(),
+      }
+    ).encode("utf-8")
+
+    while True:
+      try:
+        if DEBUG_DISCOVERY >= 3:
+          print(f"Broadcast presence: {message}")
+        transport.sendto(message, ("<broadcast>", self.broadcast_port))
+        await asyncio.sleep(self.broadcast_interval)
+      except Exception as e:
+        print(f"Error in broadcast presence: {e}")
+        import traceback
+
+        print(traceback.format_exc())
+
+  async def on_listen_message(self, data, addr):
+    if not data:
+      return
+
+    decoded_data = data.decode("utf-8", errors="ignore")
+
+    # Check if the decoded data starts with a valid JSON character
+    if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
+      return
+
+    try:
+      decoder = json.JSONDecoder(strict=False)
+      message = decoder.decode(decoded_data)
+    except json.JSONDecodeError as e:
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Error decoding JSON data from {addr}: {e}")
+      return
+
+    if DEBUG_DISCOVERY >= 2:
+      print(f"received from peer {addr}: {message}")
+
+    if message["type"] == "discovery" and message["node_id"] != self.node_id:
+      peer_id = message["node_id"]
+      peer_host = addr[0]
+      peer_port = message["grpc_port"]
+      device_capabilities = DeviceCapabilities(**message["device_capabilities"])
+      if peer_id not in self.known_peers:
+        self.known_peers[peer_id] = (
+          GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
+          time.time(),
+          time.time(),
+        )
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
+      self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+
+  async def task_listen_for_peers(self):
+    await asyncio.get_event_loop().create_datagram_endpoint(
+      lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port)
+    )
+    if DEBUG_DISCOVERY >= 2:
+      print("Started listen task")
+
+  async def task_cleanup_peers(self):
+    while True:
+      try:
+        current_time = time.time()
+        peers_to_remove = [
+          peer_handle.id()
+          for peer_handle, connected_at, last_seen in self.known_peers.values()
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout)
+          or current_time - last_seen > self.discovery_timeout
+        ]
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:",
+            {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            },
+          )
+        if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
+          print(f"Cleaning up peers: {peers_to_remove}")
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2:
+            print(f"Removed peer {peer_id} due to inactivity.")
+        await asyncio.sleep(self.broadcast_interval)
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        import traceback
+
+        print(traceback.format_exc())
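
The broadcast payload used by GRPCDiscovery above is a small JSON datagram; the field names come from task_broadcast_presence, while the concrete values below are placeholders. A round-trip sketch (illustrative only, not part of this commit):

import json

# Values are placeholders; the real message embeds DeviceCapabilities.to_dict().
payload = json.dumps({
  "type": "discovery",
  "node_id": "node-a",
  "grpc_port": 50051,
  "device_capabilities": {},
}).encode("utf-8")

message = json.loads(payload.decode("utf-8", errors="ignore"))
assert message["type"] == "discovery" and message["grpc_port"] == 50051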

+ 102 - 81
exo/networking/grpc/grpc_peer_handle.py

@@ -11,85 +11,106 @@ from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities
 
+
 class GRPCPeerHandle(PeerHandle):
-    def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
-        self._id = id
-        self.address = address
-        self._device_capabilities = device_capabilities
-        self.channel = None
-        self.stub = None
-
-    def id(self) -> str:
-        return self._id
-
-    def device_capabilities(self) -> DeviceCapabilities:
-        return self._device_capabilities
-
-    async def connect(self):
-        self.channel = grpc.aio.insecure_channel(self.address, options=[
-            ('grpc.max_metadata_size', 32*1024*1024)
-        ])
-        self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-
-    async def is_connected(self) -> bool:
-        return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
-
-    async def disconnect(self):
-        if self.channel:
-            await self.channel.close()
-        self.channel = None
-        self.stub = None
-
-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id, inference_state=inference_state)
-        response = await self.stub.SendPrompt(request)
-
-        if not response.tensor_data or not response.shape or not response.dtype:
-            return None
-
-        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        request = node_service_pb2.TensorRequest(
-            shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
-            tensor = node_service_pb2.Tensor(
-                tensor_data=tensor.tobytes(),
-                shape=tensor.shape,
-                dtype=str(tensor.dtype)
-            ),
-            request_id=request_id,
-            inference_state=inference_state
-        )
-        response = await self.stub.SendTensor(request)
-
-        if not response.tensor_data or not response.shape or not response.dtype:
-            return None
-
-        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
-        response = await self.stub.GetInferenceResult(request)
-        if response.tensor is None:
-            return None, response.is_finished
-        return np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), response.is_finished
-
-    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-        request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
-        response = await self.stub.CollectTopology(request)
-        topology = Topology()
-        for node_id, capabilities in response.nodes.items():
-            device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
-            topology.update_node(node_id, device_capabilities)
-        for node_id, peers in response.peer_graph.items():
-            for peer_id in peers.peer_ids:
-                topology.add_edge(node_id, peer_id)
-        return topology
-
-    async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-        request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
-        await self.stub.SendResult(request)
-
-    async def send_opaque_status(self, request_id: str, status: str) -> None:
-        request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
-        await self.stub.SendOpaqueStatus(request)
+  def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
+    self._id = id
+    self.address = address
+    self._device_capabilities = device_capabilities
+    self.channel = None
+    self.stub = None
+
+  def id(self) -> str:
+    return self._id
+
+  def device_capabilities(self) -> DeviceCapabilities:
+    return self._device_capabilities
+
+  async def connect(self):
+    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
+    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+
+  async def is_connected(self) -> bool:
+    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
+
+  async def disconnect(self):
+    if self.channel:
+      await self.channel.close()
+    self.channel = None
+    self.stub = None
+
+  async def send_prompt(
+    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.array]:
+    request = node_service_pb2.PromptRequest(
+      prompt=prompt,
+      shard=node_service_pb2.Shard(
+        model_id=shard.model_id,
+        start_layer=shard.start_layer,
+        end_layer=shard.end_layer,
+        n_layers=shard.n_layers,
+      ),
+      request_id=request_id,
+      inference_state=inference_state,
+    )
+    response = await self.stub.SendPrompt(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+  async def send_tensor(
+    self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.array]:
+    request = node_service_pb2.TensorRequest(
+      shard=node_service_pb2.Shard(
+        model_id=shard.model_id,
+        start_layer=shard.start_layer,
+        end_layer=shard.end_layer,
+        n_layers=shard.n_layers,
+      ),
+      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
+      request_id=request_id,
+      inference_state=inference_state,
+    )
+    response = await self.stub.SendTensor(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
+    response = await self.stub.GetInferenceResult(request)
+    if response.tensor is None:
+      return None, response.is_finished
+    return (
+      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(
+        response.tensor.shape
+      ),
+      response.is_finished,
+    )
+
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
+    response = await self.stub.CollectTopology(request)
+    topology = Topology()
+    for node_id, capabilities in response.nodes.items():
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops
+      )
+      topology.update_node(node_id, device_capabilities)
+    for node_id, peers in response.peer_graph.items():
+      for peer_id in peers.peer_ids:
+        topology.add_edge(node_id, peer_id)
+    return topology
+
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
+    await self.stub.SendResult(request)
+
+  async def send_opaque_status(self, request_id: str, status: str) -> None:
+    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
+    await self.stub.SendOpaqueStatus(request)
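
send_tensor and the response handling above serialize arrays as raw bytes plus a dtype string and a shape, then rebuild them with np.frombuffer(...).reshape(...). A self-contained round-trip sketch of that encoding (numpy only, not part of this commit):

import numpy as np

original = np.arange(6, dtype=np.float32).reshape(2, 3)
tensor_data, dtype, shape = original.tobytes(), str(original.dtype), original.shape

restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert np.array_equal(original, restored)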

+ 106 - 65
exo/networking/grpc/grpc_server.py

@@ -8,78 +8,119 @@ from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 
+
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
-    def __init__(self, node: Node, host: str, port: int):
-        self.node = node
-        self.host = host
-        self.port = port
-        self.server = None
+  def __init__(self, node: Node, host: str, port: int):
+    self.node = node
+    self.host = host
+    self.port = port
+    self.server = None
 
-    async def start(self) -> None:
-        self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-            ('grpc.max_metadata_size', 32*1024*1024),
-            ('grpc.max_send_message_length', 128 * 1024 * 1024),
-            ('grpc.max_receive_message_length', 128 * 1024 * 1024),
-        ])
-        node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
-        listen_addr = f'{self.host}:{self.port}'
-        self.server.add_insecure_port(listen_addr)
-        await self.server.start()
-        if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
+  async def start(self) -> None:
+    self.server = grpc.aio.server(
+      futures.ThreadPoolExecutor(max_workers=10),
+      options=[
+        ("grpc.max_metadata_size", 32 * 1024 * 1024),
+        ("grpc.max_send_message_length", 128 * 1024 * 1024),
+        ("grpc.max_receive_message_length", 128 * 1024 * 1024),
+      ],
+    )
+    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
+    listen_addr = f"{self.host}:{self.port}"
+    self.server.add_insecure_port(listen_addr)
+    await self.server.start()
+    if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
 
-    async def stop(self) -> None:
-        if self.server:
-            await self.server.stop(grace=5)
-            await self.server.wait_for_termination()
-            if DEBUG >= 1: print("Server stopped and all connections are closed")
+  async def stop(self) -> None:
+    if self.server:
+      await self.server.stop(grace=5)
+      await self.server.wait_for_termination()
+      if DEBUG >= 1: print("Server stopped and all connections are closed")
 
-    async def SendPrompt(self, request, context):
-        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        prompt = request.prompt
-        request_id = request.request_id
-        result = await self.node.process_prompt(shard, prompt, request_id)
-        if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
-        tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+  async def SendPrompt(self, request, context):
+    shard = Shard(
+      model_id=request.shard.model_id,
+      start_layer=request.shard.start_layer,
+      end_layer=request.shard.end_layer,
+      n_layers=request.shard.n_layers,
+    )
+    prompt = request.prompt
+    request_id = request.request_id
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return (
+      node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+      if result is not None
+      else node_service_pb2.Tensor()
+    )
 
-    async def SendTensor(self, request, context):
-        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
-        request_id = request.request_id
-        inference_state = request.inference_state
+  async def SendTensor(self, request, context):
+    shard = Shard(
+      model_id=request.shard.model_id,
+      start_layer=request.shard.start_layer,
+      end_layer=request.shard.end_layer,
+      n_layers=request.shard.n_layers,
+    )
+    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(
+      request.tensor.shape
+    )
+    request_id = request.request_id
+    inference_state = request.inference_state
 
-        result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
-        if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
-        tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return (
+      node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype))
+      if result is not None
+      else node_service_pb2.Tensor()
+    )
 
-    async def GetInferenceResult(self, request, context):
-        request_id = request.request_id
-        result = await self.node.get_inference_result(request_id)
-        if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
-        tensor_data = result[0].tobytes() if result[0] is not None else None
-        return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), is_finished=result[1]) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
+  async def GetInferenceResult(self, request, context):
+    request_id = request.request_id
+    result = await self.node.get_inference_result(request_id)
+    if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
+    tensor_data = result[0].tobytes() if result[0] is not None else None
+    return (
+      node_service_pb2.InferenceResult(
+        tensor=node_service_pb2.Tensor(
+          tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)
+        ),
+        is_finished=result[1],
+      )
+      if result[0] is not None
+      else node_service_pb2.InferenceResult(is_finished=result[1])
+    )
 
-    async def CollectTopology(self, request, context):
-        max_depth = request.max_depth
-        visited = set(request.visited)
-        topology = await self.node.collect_topology(visited, max_depth)
-        nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory, flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8)) for node_id, cap in topology.nodes.items()}
-        peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
-        if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
-        return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
+  async def CollectTopology(self, request, context):
+    max_depth = request.max_depth
+    visited = set(request.visited)
+    topology = await self.node.collect_topology(visited, max_depth)
+    nodes = {
+      node_id: node_service_pb2.DeviceCapabilities(
+        model=cap.model,
+        chip=cap.chip,
+        memory=cap.memory,
+        flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
+      )
+      for node_id, cap in topology.nodes.items()
+    }
+    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
+    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
-    async def SendResult(self, request, context):
-        request_id = request.request_id
-        result = request.result
-        is_finished = request.is_finished
-        if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
-        self.node.on_token.trigger_all(request_id, result, is_finished)
-        return node_service_pb2.Empty()
+  async def SendResult(self, request, context):
+    request_id = request.request_id
+    result = request.result
+    is_finished = request.is_finished
+    if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    self.node.on_token.trigger_all(request_id, result, is_finished)
+    return node_service_pb2.Empty()
 
-    async def SendOpaqueStatus(self, request, context):
-        request_id = request.request_id
-        status = request.status
-        if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
-        self.node.on_opaque_status.trigger_all(request_id, status)
-        return node_service_pb2.Empty()
+  async def SendOpaqueStatus(self, request, context):
+    request_id = request.request_id
+    status = request.status
+    if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+    self.node.on_opaque_status.trigger_all(request_id, status)
+    return node_service_pb2.Empty()

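The reformatting above leaves the server lifecycle unchanged: construct with a Node plus bind address, start(), serve RPCs, stop(). A minimal sketch of driving that lifecycle from asyncio; the bind address, port and park-until-cancelled pattern are illustrative choices, not taken from this commit.

import asyncio

from exo.networking.grpc.grpc_server import GRPCServer

async def serve(node) -> None:
  # `node` is whatever Node implementation the caller has wired up (placeholder here).
  server = GRPCServer(node, "0.0.0.0", 50051)
  await server.start()  # binds an insecure port with the 128 MB message limits configured above
  try:
    await asyncio.Event().wait()  # park until the task is cancelled
  finally:
    await server.stop()  # graceful stop, honouring the 5 s grace period shown above
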
+ 14 - 13
exo/networking/grpc/test_grpc_discovery.py

@@ -2,20 +2,21 @@ import asyncio
 import unittest
 from .grpc_discovery import GRPCDiscovery
 
+
 class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
-        self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
-        await self.node1.start()
-        await self.node2.start()
+  async def asyncSetUp(self):
+    self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
+    self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
+    await self.node1.start()
+    await self.node2.start()
 
-    async def asyncTearDown(self):
-        await self.node1.stop()
-        await self.node2.stop()
+  async def asyncTearDown(self):
+    await self.node1.stop()
+    await self.node2.stop()
 
-    async def test_discovery(self):
-        await asyncio.sleep(4)
+  async def test_discovery(self):
+    await asyncio.sleep(4)
 
-        # Check discovered peers
-        print("Node1 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
-        print("Node2 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))
+    # Check discovered peers
+    print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
+    print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))

+ 44 - 39
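Note that the third and fourth constructor arguments are cross-wired between the two nodes (5678/5679 vs 5679/5678), presumably the discovery listen and broadcast ports, which is what lets the two instances find each other on one machine. To run just this module outside the full suite, something like the following works (module path taken from the file header above):

import unittest

# Load and run only the gRPC discovery tests.
suite = unittest.defaultTestLoader.loadTestsFromName("exo.networking.grpc.test_grpc_discovery")
unittest.TextTestRunner(verbosity=2).run(suite)
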
exo/networking/peer_handle.py

@@ -5,43 +5,48 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 
-class PeerHandle(ABC):
-    @abstractmethod
-    def id(self) -> str:
-        pass
-
-    @abstractmethod
-    def device_capabilities(self) -> DeviceCapabilities:
-        pass
-
-    @abstractmethod
-    async def connect(self) -> None:
-        pass
-
-    @abstractmethod
-    async def is_connected(self) -> bool:
-        pass
-
-    @abstractmethod
-    async def disconnect(self) -> None:
-        pass
 
-    @abstractmethod
-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        pass
-
-    @abstractmethod
-    async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        pass
-
-    @abstractmethod
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        pass
-
-    @abstractmethod
-    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-        pass
-
-    @abstractmethod
-    async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-        pass
+class PeerHandle(ABC):
+  @abstractmethod
+  def id(self) -> str:
+    pass
+
+  @abstractmethod
+  def device_capabilities(self) -> DeviceCapabilities:
+    pass
+
+  @abstractmethod
+  async def connect(self) -> None:
+    pass
+
+  @abstractmethod
+  async def is_connected(self) -> bool:
+    pass
+
+  @abstractmethod
+  async def disconnect(self) -> None:
+    pass
+
+  @abstractmethod
+  async def send_prompt(
+    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.array]:
+    pass
+
+  @abstractmethod
+  async def send_tensor(
+    self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.array]:
+    pass
+
+  @abstractmethod
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    pass
+
+  @abstractmethod
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    pass
+
+  @abstractmethod
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    pass

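Because PeerHandle is a plain ABC, a test double only needs to supply these ten methods. Below is a minimal in-process stub that records results instead of issuing RPCs; everything in it (class name, capability values) is illustrative, not part of this commit.

from typing import List, Optional, Tuple

import numpy as np

from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.topology.topology import Topology


class DummyPeerHandle(PeerHandle):
  """In-memory stand-in that records calls instead of talking to a remote node."""

  def __init__(self, peer_id: str):
    self._id = peer_id
    self._connected = False
    self.sent_results: List[Tuple[str, List[int], bool]] = []

  def id(self) -> str:
    return self._id

  def device_capabilities(self) -> DeviceCapabilities:
    return DeviceCapabilities(model="dummy", chip="dummy", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))

  async def connect(self) -> None:
    self._connected = True

  async def is_connected(self) -> bool:
    return self._connected

  async def disconnect(self) -> None:
    self._connected = False

  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
    return None

  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
    return None

  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
    return None, False

  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
    return Topology()

  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
    self.sent_results.append((request_id, result, is_finished))
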
+ 7 - 6
exo/networking/server.py

@@ -1,10 +1,11 @@
 from abc import ABC, abstractmethod
 
+
 class Server(ABC):
-    @abstractmethod
-    async def start(self) -> None:
-        pass
+  @abstractmethod
+  async def start(self) -> None:
+    pass
 
-    @abstractmethod
-    async def stop(self) -> None:
-        pass
+  @abstractmethod
+  async def stop(self) -> None:
+    pass

+ 43 - 38
exo/orchestration/node.py

@@ -5,42 +5,47 @@ from exo.helpers import AsyncCallbackSystem
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 
+
 class Node(ABC):
-    @abstractmethod
-    async def start(self, wait_for_peers: int = 0) -> None:
-        pass
-
-    @abstractmethod
-    async def stop(self) -> None:
-        pass
-
-    @abstractmethod
-    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        pass
-
-    @abstractmethod
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        pass
-
-    @abstractmethod
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        pass
-
-    @abstractmethod
-    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
-        pass
-
-    @property
-    @abstractmethod
-    def current_topology(self) -> Topology:
-        pass
-
-    @property
-    @abstractmethod
-    def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
-        pass
-
-    @property
-    @abstractmethod
-    def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
-        pass
+  @abstractmethod
+  async def start(self, wait_for_peers: int = 0) -> None:
+    pass
+
+  @abstractmethod
+  async def stop(self) -> None:
+    pass
+
+  @abstractmethod
+  async def process_prompt(
+    self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def process_tensor(
+    self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    pass
+
+  @abstractmethod
+  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
+    pass
+
+  @property
+  @abstractmethod
+  def current_topology(self) -> Topology:
+    pass
+
+  @property
+  @abstractmethod
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+    pass
+
+  @property
+  @abstractmethod
+  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+    pass

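The two callback properties are the integration points for the rest of the system: on_token streams the buffered token list per request, and on_opaque_status streams the JSON node_status strings seen throughout this commit. Subscribing follows the register(name).on_next(fn) pattern also used by the metrics module below; the subscriber name here is an arbitrary placeholder.

from typing import List


def watch_tokens(node) -> None:
  # `node` is any Node implementation; "token_printer" is just an illustrative subscriber key.
  def on_tokens(request_id: str, tokens: List[int], is_finished: bool) -> None:
    print(f"[{request_id}] {len(tokens)} tokens buffered, finished={is_finished}")

  node.on_token.register("token_printer").on_next(on_tokens)
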
+ 405 - 266
exo/orchestration/standard_node.py

@@ -14,271 +14,410 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 
+
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None, disable_tui: Optional[bool] = False):
-        self.id = id
-        self.inference_engine = inference_engine
-        self.server = server
-        self.discovery = discovery
-        self.partitioning_strategy = partitioning_strategy
-        self.peers: List[PeerHandle] = {}
-        self.topology: Topology = Topology()
-        self.device_capabilities = device_capabilities()
-        self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
-        self.max_generate_tokens = max_generate_tokens
-        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
-        self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
-        self._on_opaque_status.register("node_status").on_next(self.on_node_status)
-
-    def on_node_status(self, request_id, opaque_status):
-        try:
-            status_data = json.loads(opaque_status)
-            if status_data.get("type", "") == "node_status":
-                if status_data.get("status", "").startswith("start_"):
-                    self.current_topology.active_node_id = status_data.get("node_id")
-                elif status_data.get("status", "").startswith("end_"):
-                    if status_data.get("node_id") == self.current_topology.active_node_id:
-                        self.current_topology.active_node_id = None
-            if self.topology_viz: self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
-        except json.JSONDecodeError:
-            pass
-
-    async def start(self, wait_for_peers: int = 0) -> None:
-        await self.server.start()
-        await self.discovery.start()
-        await self.update_peers(wait_for_peers)
+  def __init__(
+    self,
+    id: str,
+    server: Server,
+    inference_engine: InferenceEngine,
+    discovery: Discovery,
+    partitioning_strategy: PartitioningStrategy = None,
+    max_generate_tokens: int = 256,
+    chatgpt_api_endpoint: Optional[str] = None,
+    web_chat_url: Optional[str] = None,
+    disable_tui: Optional[bool] = False,
+  ):
+    self.id = id
+    self.inference_engine = inference_engine
+    self.server = server
+    self.discovery = discovery
+    self.partitioning_strategy = partitioning_strategy
+    self.peers: List[PeerHandle] = []
+    self.topology: Topology = Topology()
+    self.device_capabilities = device_capabilities()
+    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.topology_viz = (
+      TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url)
+      if not disable_tui
+      else None
+    )
+    self.max_generate_tokens = max_generate_tokens
+    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+    self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
+    self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+
+  def on_node_status(self, request_id, opaque_status):
+    try:
+      status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "node_status":
+        if status_data.get("status", "").startswith("start_"):
+          self.current_topology.active_node_id = status_data.get("node_id")
+        elif status_data.get("status", "").startswith("end_"):
+          if status_data.get("node_id") == self.current_topology.active_node_id:
+            self.current_topology.active_node_id = None
+      if self.topology_viz:
+        self.topology_viz.update_visualization(
+          self.current_topology, self.partitioning_strategy.partition(self.current_topology)
+        )
+    except json.JSONDecodeError:
+      pass
+
+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
+  async def process_prompt(
+    self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
+    shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "start_process_prompt",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "prompt": prompt,
+            "inference_state": inference_state,
+            "request_id": request_id,
+          }
+        ),
+      )
+    )
+    start_time = time.perf_counter_ns()
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    end_time = time.perf_counter_ns()
+    elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "end_process_prompt",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "prompt": prompt,
+            "inference_state": inference_state,
+            "request_id": request_id,
+            "elapsed_time_ns": elapsed_time_ns,
+            "result_size": resp.size if resp is not None else 0,
+          }
+        ),
+      )
+    )
+    return resp
+
+  async def _process_prompt(
+    self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None
+  ) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    shard = self.get_current_shard(base_shard)
+
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if shard.start_layer != 0:
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      await self.forward_to_next_shard(shard, prompt, request_id)
+      return
+
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(
+      request_id, shard, prompt, inference_state=inference_state
+    )
+    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+    asyncio.create_task(
+      self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)
+    )  # TODO: this is n^2 communication complexity
+
+    if result.size == 1:
+      self.buffered_token_output[request_id][0].append(result.item())
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+
+    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+
+    if not is_finished:
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+    return (
+      np.array(self.buffered_token_output[request_id][0])
+      if len(self.buffered_token_output[request_id][0]) > 0
+      else None
+    )
+
+  async def process_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
+    shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "start_process_tensor",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "tensor_size": tensor.size,
+            "tensor_shape": tensor.shape,
+            "request_id": request_id,
+            "inference_state": inference_state,
+          }
+        ),
+      )
+    )
+    start_time = time.perf_counter_ns()
+    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    end_time = time.perf_counter_ns()
+    elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "end_process_tensor",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "request_id": request_id,
+            "elapsed_time_ns": elapsed_time_ns,
+            "result_size": resp.size if resp is not None else 0,
+          }
+        ),
+      )
+    )
+    return resp
+
+  async def _process_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    shard = self.get_current_shard(base_shard)
+
+    try:
+      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
+      result, inference_state, is_finished = await self.inference_engine.infer_tensor(
+        request_id, shard, tensor, inference_state=inference_state
+      )
+      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      if is_finished:
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      asyncio.create_task(
+        self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)
+      )  # TODO: this is n^2 communication complexity
+
+      if result.size == 1:  # we got a new token out
+        self.buffered_token_output[request_id][0].append(result.item())
+        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+
+      if not is_finished:
+        asyncio.create_task(
+          self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state)
+        )
+
+      return (
+        np.array(self.buffered_token_output[request_id][0])
+        if len(self.buffered_token_output[request_id][0]) > 0
+        else None
+      )
+    except Exception as e:
+      print(f"Error processing tensor for shard {shard}: {e}")
+      import traceback
+
+      traceback.print_exc()
+      return None
+
+  async def forward_to_next_shard(
+    self,
+    base_shard: Shard,
+    tensor_or_prompt: Union[np.ndarray, str],
+    request_id: str,
+    inference_state: Optional[str] = None,
+  ) -> None:
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return
+    shard = self.get_current_shard(base_shard)
+
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(
+      self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id
+    )
+    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
+    if current_partition_index is not None:
+      next_partition_index = (current_partition_index + 1) % len(partitions)
+      next_partition: Partition = partitions[next_partition_index]
+      next_shard = shards[next_partition_index]
+      if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
+
+      if next_partition.node_id == self.id:
+        if isinstance(tensor_or_prompt, np.ndarray):
+          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        else:
+          await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        return
+
+      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {next_partition} not found")
+
+      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
+
+      if isinstance(tensor_or_prompt, np.ndarray):
+        await target_peer.send_tensor(
+          next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state
+        )
+      else:
+        await target_peer.send_prompt(
+          next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state
+        )
+
+  def get_current_shard(self, base_shard: Shard) -> Shard:
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+    return shards[current_partition_index]
+
+  async def update_peers(self, wait_for_peers: int = 0) -> None:
+    self.peers = await self.discovery.discover_peers(wait_for_peers)
+    if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
+    if DEBUG >= 2: print("Connecting to new peers...")
+    for peer in self.peers:
+      is_connected = await peer.is_connected()
+      if DEBUG >= 2 and is_connected:
+        print(f"Already connected to {peer.id()}: {is_connected}")
+      if not is_connected:
+        await peer.connect()
+        if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+
+  async def periodic_topology_collection(self, interval: int):
+    while True:
+      await asyncio.sleep(interval)
+      try:
+        await self.update_peers()
         await self.collect_topology()
-        if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-        asyncio.create_task(self.periodic_topology_collection(5))
-
-    async def stop(self) -> None:
-        await self.discovery.stop()
-        await self.server.stop()
-
-    async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        shard = self.get_current_shard(base_shard)
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
-        start_time = time.perf_counter_ns()
-        resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
-        end_time = time.perf_counter_ns()
-        elapsed_time_ns = end_time - start_time
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
-        return resp
-
-    async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        if request_id is None:
-            request_id = str(uuid.uuid4())
-        if request_id not in self.buffered_token_output:
-            self.buffered_token_output[request_id] = ([], False)
-        shard = self.get_current_shard(base_shard)
-
-        if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-        if shard.start_layer != 0:
-            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-            await self.forward_to_next_shard(shard, prompt, request_id)
-            return
-
-        result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
-        is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-        if is_finished:
-            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-        asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity
-
-        if result.size == 1:
-            self.buffered_token_output[request_id][0].append(result.item())
-            self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-        if not is_finished:
-            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-        return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
-
-    async def process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        shard = self.get_current_shard(base_shard)
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
-        start_time = time.perf_counter_ns()
-        resp = await self._process_tensor(shard, tensor, request_id, inference_state)
-        end_time = time.perf_counter_ns()
-        elapsed_time_ns = end_time - start_time
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
-        return resp
-
-    async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        if request_id is None:
-            request_id = str(uuid.uuid4())
-        if request_id not in self.buffered_token_output:
-            self.buffered_token_output[request_id] = ([], False)
-        shard = self.get_current_shard(base_shard)
-
-        try:
-            if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-            result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-            is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-            if is_finished:
-                self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-            asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity
-
-            if result.size == 1:  # we got a new token out
-                self.buffered_token_output[request_id][0].append(result.item())
-                self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-            if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-            if not is_finished:
-                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-            return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
-        except Exception as e:
-            print(f"Error processing tensor for shard {shard}: {e}")
-            import traceback
-            traceback.print_exc()
-            return None
-
-    async def forward_to_next_shard(self, base_shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
-        if not self.partitioning_strategy:
-            if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-            return
-        shard = self.get_current_shard(base_shard)
-
-        partitions = self.partitioning_strategy.partition(self.topology)
-        shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
-        current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-        if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-        if current_partition_index is not None:
-            next_partition_index = (current_partition_index + 1) % len(partitions)
-            next_partition: Partition = partitions[next_partition_index]
-            next_shard = shards[next_partition_index]
-            if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-            if next_partition.node_id == self.id:
-                if isinstance(tensor_or_prompt, np.ndarray):
-                    await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-                else:
-                    await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-                return
-
-            target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-            if not target_peer:
-                raise ValueError(f"Peer for {next_partition} not found")
-
-            if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-            if isinstance(tensor_or_prompt, np.ndarray):
-                await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-            else:
-                await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-
-    def get_current_shard(self, base_shard: Shard) -> Shard:
-        partitions = self.partitioning_strategy.partition(self.topology)
-        shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
-        current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-        if current_partition_index is None:
-            raise ValueError(f"No current partition found for node: {self.id}")
-        return shards[current_partition_index]
-
-    async def update_peers(self, wait_for_peers: int = 0) -> None:
-        self.peers = await self.discovery.discover_peers(wait_for_peers)
-        if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
-        if DEBUG >= 2: print("Connecting to new peers...")
-        for peer in self.peers:
-            is_connected = await peer.is_connected()
-            if DEBUG >= 2 and is_connected: print(f"Already connected to {peer.id()}: {is_connected}")
-            if not is_connected:
-                await peer.connect()
-                if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
-
-    async def periodic_topology_collection(self, interval: int):
-        while True:
-            await asyncio.sleep(interval)
-            try:
-                await self.update_peers()
-                await self.collect_topology()
-            except Exception as e:
-                print(f"Error collecting topology: {e}")
-
-            if DEBUG >= 2: print("Topology collection task executed.")
-            if DEBUG >= 2: print(f"Current topology: {self.topology}")
-
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        if request_id not in self.buffered_token_output:
-            return None, False
-        return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
-
-    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
-        next_topology = Topology()
-        next_topology.update_node(self.id, self.device_capabilities)
-
-        if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
-
-        prev_visited = visited.copy()
-        visited.update(p.id() for p in self.peers)
-
-        for peer in self.peers:
-            next_topology.update_node(peer.id(), peer.device_capabilities())
-            next_topology.add_edge(self.id, peer.id())
-
-            if peer.id() in prev_visited:
-                if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
-                continue
-
-            if max_depth <= 0:
-                if DEBUG >= 2: print(f"Max depth reached. Skipping...")
-                continue
-
-            try:
-                other_topology = await peer.collect_topology(visited, max_depth = max_depth - 1)
-                if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-                self.topology.merge(other_topology)
-            except Exception as e:
-                print(f"Error collecting topology from {peer.id()}: {e}")
-
-        next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
-        self.topology = next_topology
-        if self.topology_viz: self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
-        return next_topology
-
-    @property
-    def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
-        return self._on_token
-
-    @property
-    def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
-        return self._on_opaque_status
-
-    def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
-        if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
-        self.on_token.trigger_all(request_id, tokens, is_finished)
-
-    async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-        async def send_result_to_peer(peer):
-            try:
-                await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
-            except asyncio.TimeoutError:
-                print(f"Timeout broadcasting result to {peer.id()}")
-            except Exception as e:
-                print(f"Error broadcasting result to {peer.id()}: {e}")
-                import traceback
-                traceback.print_exc()
-
-        await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
-
-    async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
-        async def send_status_to_peer(peer):
-            try:
-                await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
-            except asyncio.TimeoutError:
-                print(f"Timeout sending opaque status to {peer.id()}")
-            except Exception as e:
-                print(f"Error sending opaque status to {peer.id()}: {e}")
-                import traceback
-                traceback.print_exc()
-
-        await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
-        # in the case of opaque status, we also want to receive our own opaque statuses
-        self.on_opaque_status.trigger_all(request_id, status)
-
-    @property
-    def current_topology(self) -> Topology:
-        return self.topology
+      except Exception as e:
+        print(f"Error collecting topology: {e}")
+
+      if DEBUG >= 2: print("Topology collection task executed.")
+      if DEBUG >= 2: print(f"Current topology: {self.topology}")
+
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    if request_id not in self.buffered_token_output:
+      return None, False
+    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+
+  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
+    next_topology = Topology()
+    next_topology.update_node(self.id, self.device_capabilities)
+
+    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
+
+    prev_visited = visited.copy()
+    visited.update(p.id() for p in self.peers)
+
+    for peer in self.peers:
+      next_topology.update_node(peer.id(), peer.device_capabilities())
+      next_topology.add_edge(self.id, peer.id())
+
+      if peer.id() in prev_visited:
+        if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
+        continue
+
+      if max_depth <= 0:
+        if DEBUG >= 2: print(f"Max depth reached. Skipping...")
+        continue
+
+      try:
+        other_topology = await peer.collect_topology(visited, max_depth=max_depth - 1)
+        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
+        self.topology.merge(other_topology)
+      except Exception as e:
+        print(f"Error collecting topology from {peer.id()}: {e}")
+
+    next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
+    self.topology = next_topology
+    if self.topology_viz:
+      self.topology_viz.update_visualization(
+        self.current_topology, self.partitioning_strategy.partition(self.current_topology)
+      )
+    return next_topology
+
+  @property
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+    return self._on_token
+
+  @property
+  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+    return self._on_opaque_status
+
+  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
+    self.on_token.trigger_all(request_id, tokens, is_finished)
+
+  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    async def send_result_to_peer(peer):
+      try:
+        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
+      except asyncio.TimeoutError:
+        print(f"Timeout broadcasting result to {peer.id()}")
+      except Exception as e:
+        print(f"Error broadcasting result to {peer.id()}: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+
+  async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+    async def send_status_to_peer(peer):
+      try:
+        await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
+      except asyncio.TimeoutError:
+        print(f"Timeout sending opaque status to {peer.id()}")
+      except Exception as e:
+        print(f"Error sending opaque status to {peer.id()}: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+    await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
+    # in the case of opaque status, we also want to receive our own opaque statuses
+    self.on_opaque_status.trigger_all(request_id, status)
+
+  @property
+  def current_topology(self) -> Topology:
+    return self.topology

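A minimal sketch of wiring a StandardNode together with the gRPC server and discovery reformatted earlier in this diff. The ports, the construct-then-assign order for the node/server cycle, and passing None for the server argument are assumptions for illustration only; a real deployment would also supply a PartitioningStrategy so prompts can be sharded across peers.

import asyncio

from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.networking.grpc.grpc_server import GRPCServer
from exo.orchestration.standard_node import StandardNode


async def run_node(inference_engine) -> None:
  # Argument meanings for GRPCDiscovery are inferred from the discovery test earlier in this diff.
  discovery = GRPCDiscovery("node-a", 50051, 5678, 5679)
  node = StandardNode("node-a", None, inference_engine, discovery, disable_tui=True)
  node.server = GRPCServer(node, "0.0.0.0", 50051)  # resolve the node <-> server cycle after construction
  await node.start(wait_for_peers=0)
  try:
    await asyncio.Event().wait()  # keep serving until cancelled
  finally:
    await node.stop()
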
+ 51 - 48
exo/orchestration/test_node.py

@@ -5,52 +5,55 @@ import numpy as np
 from .standard_node import StandardNode
 from exo.networking.peer_handle import PeerHandle
 
+
 class TestNode(unittest.IsolatedAsyncioTestCase):
-    def setUp(self):
-        self.mock_inference_engine = AsyncMock()
-        self.mock_server = AsyncMock()
-        self.mock_server.start = AsyncMock()
-        self.mock_server.stop = AsyncMock()
-        self.mock_discovery = AsyncMock()
-        self.mock_discovery.start = AsyncMock()
-        self.mock_discovery.stop = AsyncMock()
-        mock_peer1 = Mock(spec=PeerHandle)
-        mock_peer1.id.return_value = "peer1"
-        mock_peer2 = Mock(spec=PeerHandle)
-        mock_peer2.id.return_value = "peer2"
-        self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
-
-        self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
-
-    async def asyncSetUp(self):
-        await self.node.start()
-
-    async def asyncTearDown(self):
-        await self.node.stop()
-
-    async def test_node_initialization(self):
-        self.assertEqual(self.node.node_id, "test_node")
-        self.assertEqual(self.node.host, "localhost")
-        self.assertEqual(self.node.port, 50051)
-
-    async def test_node_start(self):
-        self.mock_server.start.assert_called_once_with("localhost", 50051)
-
-    async def test_node_stop(self):
-        await self.node.stop()
-        self.mock_server.stop.assert_called_once()
-
-    async def test_discover_and_connect_to_peers(self):
-        await self.node.discover_and_connect_to_peers()
-        self.assertEqual(len(self.node.peers), 2)
-        self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
-        self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
-
-    async def test_process_tensor_calls_inference_engine(self):
-        mock_peer = Mock()
-        self.node.peers = [mock_peer]
-
-        input_tensor = np.array([69, 1, 2])
-        await self.node.process_tensor(input_tensor, None)
-
-        self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
+  def setUp(self):
+    self.mock_inference_engine = AsyncMock()
+    self.mock_server = AsyncMock()
+    self.mock_server.start = AsyncMock()
+    self.mock_server.stop = AsyncMock()
+    self.mock_discovery = AsyncMock()
+    self.mock_discovery.start = AsyncMock()
+    self.mock_discovery.stop = AsyncMock()
+    mock_peer1 = Mock(spec=PeerHandle)
+    mock_peer1.id.return_value = "peer1"
+    mock_peer2 = Mock(spec=PeerHandle)
+    mock_peer2.id.return_value = "peer2"
+    self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
+
+    self.node = StandardNode(
+      "test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery
+    )
+
+  async def asyncSetUp(self):
+    await self.node.start()
+
+  async def asyncTearDown(self):
+    await self.node.stop()
+
+  async def test_node_initialization(self):
+    self.assertEqual(self.node.node_id, "test_node")
+    self.assertEqual(self.node.host, "localhost")
+    self.assertEqual(self.node.port, 50051)
+
+  async def test_node_start(self):
+    self.mock_server.start.assert_called_once_with("localhost", 50051)
+
+  async def test_node_stop(self):
+    await self.node.stop()
+    self.mock_server.stop.assert_called_once()
+
+  async def test_discover_and_connect_to_peers(self):
+    await self.node.discover_and_connect_to_peers()
+    self.assertEqual(len(self.node.peers), 2)
+    self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
+    self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
+
+  async def test_process_tensor_calls_inference_engine(self):
+    mock_peer = Mock()
+    self.node.peers = [mock_peer]
+
+    input_tensor = np.array([69, 1, 2])
+    await self.node.process_tensor(input_tensor, None)
+
+    self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)

+ 19 - 17
exo/stats/metrics.py

@@ -4,25 +4,27 @@ import json
 from typing import List
 
 # Create metrics to track time spent and requests made.
-PROCESS_PROMPT_COUNTER = Counter('process_prompt_total', 'Total number of prompts processed', ['node_id'])
-PROCESS_TENSOR_COUNTER = Counter('process_tensor_total', 'Total number of tensors processed', ['node_id'])
-PROCESS_TENSOR_TIME = Histogram('process_tensor_seconds', 'Time spent processing tensor', ['node_id'])
+PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
+PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
+PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
+
 
 def start_metrics_server(node: Node, port: int):
-    start_http_server(port)
+  start_http_server(port)
 
-    def _on_opaque_status(request_id, opaque_status: str):
-        status_data = json.loads(opaque_status)
-        type = status_data.get("type", "")
-        node_id = status_data.get("node_id", "")
-        if type != "node_status": return
-        status = status_data.get("status", "")
+  def _on_opaque_status(request_id, opaque_status: str):
+    status_data = json.loads(opaque_status)
+    type = status_data.get("type", "")
+    node_id = status_data.get("node_id", "")
+    if type != "node_status":
+      return
+    status = status_data.get("status", "")
 
-        if status == "end_process_prompt":
-            PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
-        elif status == "end_process_tensor":
-            elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
-            PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
-            PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9)  # Convert ns to seconds
+    if status == "end_process_prompt":
+      PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
+    elif status == "end_process_tensor":
+      elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
+      PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
+      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9)  # Convert ns to seconds
 
-    node.on_opaque_status.register("stats").on_next(_on_opaque_status)
+  node.on_opaque_status.register("stats").on_next(_on_opaque_status)

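For reference, the histogram above is driven purely by the end_process_tensor statuses that StandardNode.process_tensor broadcasts. A sketch of that payload and the nanosecond-to-second conversion; the request id and timing values are made up, while the field names match the broadcast earlier in this diff.

import json

status = json.dumps({
  "type": "node_status",
  "node_id": "node-a",
  "status": "end_process_tensor",
  "request_id": "req-123",
  "elapsed_time_ns": 42_000_000,  # recorded as 0.042 s after the ns -> s conversion above
  "result_size": 1,
})
# Delivering it through the "stats" subscriber registered above is equivalent to:
#   node.on_opaque_status.trigger_all("req-123", status)
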
+ 33 - 30
exo/test_callbacks.py

@@ -2,46 +2,49 @@ import asyncio
 from typing import Any, Callable
 from exo.helpers import AsyncCallbackSystem, AsyncCallback
 
+
 # Usage example
 async def main() -> None:
-    callback_system = AsyncCallbackSystem[str, Any]()
+  callback_system = AsyncCallbackSystem[str, Any]()
+
+  # Register callbacks
+  callback1 = callback_system.register("callback1")
+  callback2 = callback_system.register("callback2")
+
+  def on_next_callback(name: str) -> Callable[..., None]:
+    def callback(*args: Any) -> None:
+      print(f"{name} received values: {args}")
 
-    # Register callbacks
-    callback1 = callback_system.register("callback1")
-    callback2 = callback_system.register("callback2")
+    return callback
 
-    def on_next_callback(name: str) -> Callable[..., None]:
-        def callback(*args: Any) -> None:
-            print(f"{name} received values: {args}")
-        return callback
+  callback1.on_next(on_next_callback("Callback1"))
+  callback2.on_next(on_next_callback("Callback2"))
 
-    callback1.on_next(on_next_callback("Callback1"))
-    callback2.on_next(on_next_callback("Callback2"))
+  async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
+    try:
+      result = await callback.wait(condition, timeout=2)
+      print(f"{name} wait completed with result: {result}")
+    except asyncio.TimeoutError:
+      print(f"{name} wait timed out")
 
-    async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
-        try:
-            result = await callback.wait(condition, timeout=2)
-            print(f"{name} wait completed with result: {result}")
-        except asyncio.TimeoutError:
-            print(f"{name} wait timed out")
+  # Trigger all callbacks at once
+  callback_system.trigger_all("Hello", 42, True)
 
-    # Trigger all callbacks at once
-    callback_system.trigger_all("Hello", 42, True)
+  # Wait for all callbacks with different conditions
+  await asyncio.gather(
+    wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
+    wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True),
+  )
 
-    # Wait for all callbacks with different conditions
-    await asyncio.gather(
-        wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
-        wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True)
-    )
+  # Trigger individual callback
+  callback_system.trigger("callback2", "World", -10, False)
 
-    # Trigger individual callback
-    callback_system.trigger("callback2", "World", -10, False)
+  # Demonstrate timeout
+  new_callback = callback_system.register("new_callback")
+  new_callback.on_next(on_next_callback("NewCallback"))
+  await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
 
-    # Demonstrate timeout
-    new_callback = callback_system.register("new_callback")
-    new_callback.on_next(on_next_callback("NewCallback"))
-    await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
+  callback_system.trigger("callback2", "World", 200, False)
 
-    callback_system.trigger("callback2", "World", 200, False)
 
 asyncio.run(main())

+ 156 - 130
exo/topology/device_capabilities.py

@@ -5,153 +5,179 @@ import psutil
 
 TFLOPS = 1.00
 
+
 @dataclass
 class DeviceFlops:
-    # units of TFLOPS
-    fp32: float
-    fp16: float
-    int8: float
+  # units of TFLOPS
+  fp32: float
+  fp16: float
+  int8: float
+
+  def __str__(self):
+    return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
 
-    def __str__(self):
-        return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
+  def to_dict(self):
+    return asdict(self)
 
-    def to_dict(self):
-        return asdict(self)
 
 @dataclass
 class DeviceCapabilities:
-    model: str
-    chip: str
-    memory: int
-    flops: DeviceFlops
+  model: str
+  chip: str
+  memory: int
+  flops: DeviceFlops
+
+  def __str__(self):
+    return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
 
-    def __str__(self):
-        return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
+  def __post_init__(self):
+    if isinstance(self.flops, dict):
+      self.flops = DeviceFlops(**self.flops)
 
-    def __post_init__(self):
-        if isinstance(self.flops, dict):
-            self.flops = DeviceFlops(**self.flops)
+  def to_dict(self):
+    return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()}
 
-    def to_dict(self):
-        return {
-            'model': self.model,
-            'chip': self.chip,
-            'memory': self.memory,
-            'flops': self.flops.to_dict()
-        }
 
-UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(
+  model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)
+)
 
 CHIP_FLOPS = {
-    # Source: https://www.cpu-monkey.com
-    # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
-    ### M chips
-    "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
-    "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
-    "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
-    "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
-    "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-    "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
-    "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
-    "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
-    "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-    "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
-    "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-    "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-    ### A chips
-    "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
-    "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
-    "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
-    "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
-    "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
-    ### NVIDIA GPUs
-    #RTX 40 series
-    "Nvidia GeForce RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
-    "Nvidia GeForce RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
-    "Nvidia GeForce RTX 4080 Super": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
-    "Nvidia GeForce RTX 4070 Ti Super": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
-    "Nvidia GeForce RTX 4070 Ti": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
-    "Nvidia GeForce RTX 4070 Super": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
-    "Nvidia GeForce RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
-    "Nvidia GeForce RTX 4060 Ti 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
-    #RTX 30 series
-    "Nvidia GeForce RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
-    "Nvidia GeForce RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
-    "Nvidia GeForce RTX 3060 Ti": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
-    "Nvidia GeForce RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
-    "Nvidia GeForce RTX 3070 Ti": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
-    "Nvidia GeForce RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
-    "Nvidia GeForce RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
-    "Nvidia GeForce RTX 3080 Ti": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
-    "Nvidia GeForce RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
-    "Nvidia GeForce RTX 3090 Ti": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
-        # ... add more devices if needed ...
-    ### AMD GPUs
-    # RX 6000 series
-    "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
-    "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
-    "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
-    "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
-    "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
-    "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
-    "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
-    "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
-    "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
-    # RX 7000 series
-    "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
-    "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
-    "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
-    "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
-    "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
-    "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
-    # ... add more devices if needed ...
-    ### Qualcomm embedded chips: TODO
+  # Source: https://www.cpu-monkey.com
+  # Note: currently no distinction between variants of M3 Max and M3 Pro; we pick the lower one to be conservative
+  ### M chips
+  "Apple M1": DeviceFlops(fp32=2.29 * TFLOPS, fp16=4.58 * TFLOPS, int8=9.16 * TFLOPS),
+  "Apple M1 Pro": DeviceFlops(fp32=5.30 * TFLOPS, fp16=10.60 * TFLOPS, int8=21.20 * TFLOPS),
+  "Apple M1 Max": DeviceFlops(fp32=10.60 * TFLOPS, fp16=21.20 * TFLOPS, int8=42.40 * TFLOPS),
+  "Apple M1 Ultra": DeviceFlops(fp32=21.20 * TFLOPS, fp16=42.40 * TFLOPS, int8=84.80 * TFLOPS),
+  "Apple M2": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  "Apple M2 Pro": DeviceFlops(fp32=5.68 * TFLOPS, fp16=11.36 * TFLOPS, int8=22.72 * TFLOPS),
+  "Apple M2 Max": DeviceFlops(fp32=13.49 * TFLOPS, fp16=26.98 * TFLOPS, int8=53.96 * TFLOPS),
+  "Apple M2 Ultra": DeviceFlops(fp32=26.98 * TFLOPS, fp16=53.96 * TFLOPS, int8=107.92 * TFLOPS),
+  "Apple M3": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS),
+  "Apple M3 Pro": DeviceFlops(fp32=4.97 * TFLOPS, fp16=9.94 * TFLOPS, int8=19.88 * TFLOPS),
+  "Apple M4": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  ### A chips
+  "Apple A13 Bionic": DeviceFlops(fp32=0.69 * TFLOPS, fp16=1.38 * TFLOPS, int8=2.76 * TFLOPS),
+  "Apple A14 Bionic": DeviceFlops(fp32=0.75 * TFLOPS, fp16=1.50 * TFLOPS, int8=3.00 * TFLOPS),
+  "Apple A15 Bionic": DeviceFlops(fp32=1.37 * TFLOPS, fp16=2.74 * TFLOPS, int8=5.48 * TFLOPS),
+  "Apple A16 Bionic": DeviceFlops(fp32=1.79 * TFLOPS, fp16=3.58 * TFLOPS, int8=7.16 * TFLOPS),
+  "Apple A17 Pro": DeviceFlops(fp32=2.15 * TFLOPS, fp16=4.30 * TFLOPS, int8=8.60 * TFLOPS),
+  ### NVIDIA GPUs
+  # RTX 40 series
+  "Nvidia GeForce RTX 4090": DeviceFlops(fp32=82.58 * TFLOPS, fp16=165.16 * TFLOPS, int8=330.32 * TFLOPS),
+  "Nvidia GeForce RTX 4080": DeviceFlops(fp32=48.74 * TFLOPS, fp16=97.48 * TFLOPS, int8=194.96 * TFLOPS),
+  "Nvidia GeForce RTX 4080 Super": DeviceFlops(fp32=52.0 * TFLOPS, fp16=104.0 * TFLOPS, int8=208.0 * TFLOPS),
+  "Nvidia GeForce RTX 4070 Ti Super": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
+  "Nvidia GeForce RTX 4070 Ti": DeviceFlops(fp32=39.43 * TFLOPS, fp16=78.86 * TFLOPS, int8=157.72 * TFLOPS),
+  "Nvidia GeForce RTX 4070 Super": DeviceFlops(fp32=30.0 * TFLOPS, fp16=60.0 * TFLOPS, int8=120.0 * TFLOPS),
+  "Nvidia GeForce RTX 4070": DeviceFlops(fp32=29.0 * TFLOPS, fp16=58.0 * TFLOPS, int8=116.0 * TFLOPS),
+  "Nvidia GeForce RTX 4060 Ti 16GB": DeviceFlops(fp32=22.0 * TFLOPS, fp16=44.0 * TFLOPS, int8=88.0 * TFLOPS),
+  # RTX 30 series
+  "Nvidia GeForce RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS),
+  "Nvidia GeForce RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS),
+  "Nvidia GeForce RTX 3060 Ti": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
+  "Nvidia GeForce RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS),
+  "Nvidia GeForce RTX 3070 Ti": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS),
+  "Nvidia GeForce RTX 3080 (10 GB)": DeviceFlops(fp32=29.8 * TFLOPS, fp16=59.6 * TFLOPS, int8=119.2 * TFLOPS),
+  "Nvidia GeForce RTX 3080 (12 GB)": DeviceFlops(fp32=30.6 * TFLOPS, fp16=61.2 * TFLOPS, int8=122.4 * TFLOPS),
+  "Nvidia GeForce RTX 3080 Ti": DeviceFlops(fp32=34.1 * TFLOPS, fp16=68.2 * TFLOPS, int8=136.4 * TFLOPS),
+  "Nvidia GeForce RTX 3090": DeviceFlops(fp32=35.6 * TFLOPS, fp16=71.2 * TFLOPS, int8=142.4 * TFLOPS),
+  "Nvidia GeForce RTX 3090 Ti": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
+  # ... add more devices if needed ...
+  ### AMD GPUs
+  # RX 6000 series
+  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04 * TFLOPS, fp16=46.08 * TFLOPS, int8=92.16 * TFLOPS),
+  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74 * TFLOPS, fp16=41.48 * TFLOPS, int8=82.96 * TFLOPS),
+  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17 * TFLOPS, fp16=32.34 * TFLOPS, int8=64.68 * TFLOPS),
+  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21 * TFLOPS, fp16=26.42 * TFLOPS, int8=52.84 * TFLOPS),
+  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4 * TFLOPS, fp16=22.8 * TFLOPS, int8=45.6 * TFLOPS),
+  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6 * TFLOPS, fp16=21.2 * TFLOPS, int8=42.4 * TFLOPS),
+  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93 * TFLOPS, fp16=17.86 * TFLOPS, int8=35.72 * TFLOPS),
+  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77 * TFLOPS, fp16=11.54 * TFLOPS, int8=23.08 * TFLOPS),
+  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57 * TFLOPS, fp16=7.14 * TFLOPS, int8=14.28 * TFLOPS),
+  # RX 7000 series
+  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4 * TFLOPS, fp16=122.8 * TFLOPS, int8=245.6 * TFLOPS),
+  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4 * TFLOPS, fp16=106.8 * TFLOPS, int8=213.6 * TFLOPS),
+  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6 * TFLOPS, fp16=85.2 * TFLOPS, int8=170.4 * TFLOPS),
+  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2 * TFLOPS, fp16=68.4 * TFLOPS, int8=136.8 * TFLOPS),
+  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5 * TFLOPS, fp16=43.0 * TFLOPS, int8=86.0 * TFLOPS),
+  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
+  # ... add more devices if needed ...
+  ### Qualcomm embedded chips: TODO
 }
 
+
 def device_capabilities() -> DeviceCapabilities:
-    if psutil.MACOS:
-        return mac_device_capabilities()
-    elif psutil.LINUX:
-        return linux_device_capabilities()
-    else:
-        return DeviceCapabilities(model=f"Unknown Device", chip=f"Unknown Chip", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+  if psutil.MACOS:
+    return mac_device_capabilities()
+  elif psutil.LINUX:
+    return linux_device_capabilities()
+  else:
+    return DeviceCapabilities(
+      model=f"Unknown Device",
+      chip=f"Unknown Chip",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+
 
 def mac_device_capabilities() -> DeviceCapabilities:
-    # Fetch the model of the Mac using system_profiler
-    model = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
-    model_line = next((line for line in model.split('\n') if "Model Name" in line), None)
-    model_id = model_line.split(': ')[1] if model_line else "Unknown Model"
-    chip_line = next((line for line in model.split('\n') if "Chip" in line), None)
-    chip_id = chip_line.split(': ')[1] if chip_line else "Unknown Chip"
-    memory_line = next((line for line in model.split('\n') if "Memory" in line), None)
-    memory_str = memory_line.split(': ')[1] if memory_line else "Unknown Memory"
-    memory_units = memory_str.split()
-    memory_value = int(memory_units[0])
-    if memory_units[1] == "GB":
-        memory = memory_value * 1024
-    else:
-        memory = memory_value
-
-    # Assuming static values for other attributes for demonstration
-    return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
+  # Fetch the model of the Mac using system_profiler
+  model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
+  model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
+  model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
+  chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
+  chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
+  memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
+  memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
+  memory_units = memory_str.split()
+  memory_value = int(memory_units[0])
+  if memory_units[1] == "GB":
+    memory = memory_value * 1024
+  else:
+    memory = memory_value
+
+  # Assuming static values for other attributes for demonstration
+  return DeviceCapabilities(
+    model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))
+  )
+
 
 def linux_device_capabilities() -> DeviceCapabilities:
-    import psutil
-    from tinygrad import Device
-
-    if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
-    if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT=="GPU":
-        import pynvml, pynvml_utils
-        pynvml.nvmlInit()
-        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-        gpu_name = pynvml.nvmlDeviceGetName(handle)
-        gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-
-        if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
-
-        return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total // 2**20, flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)))
-    elif Device.DEFAULT == "AMD":
-        # TODO AMD support
-        return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-    else:
-        return DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+  import psutil
+  from tinygrad import Device
+
+  if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
+  if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
+    import pynvml, pynvml_utils
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    gpu_name = pynvml.nvmlDeviceGetName(handle)
+    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+    return DeviceCapabilities(
+      model=f"Linux Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info.total // 2**20,
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+  elif Device.DEFAULT == "AMD":
+    # TODO AMD support
+    return DeviceCapabilities(
+      model="Linux Box (AMD)",
+      chip="Unknown AMD",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+  else:
+    return DeviceCapabilities(
+      model=f"Linux Box (Device: {Device.DEFAULT})",
+      chip=f"Unknown Chip (Device: {Device.DEFAULT})",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
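
A minimal sketch of how the reformatted module is typically consumed (illustrative only; the chip name in the fallback lookup is a made-up placeholder):

  from exo.topology.device_capabilities import device_capabilities, CHIP_FLOPS, DeviceFlops

  caps = device_capabilities()  # dispatches to the macOS, Linux or generic path above via psutil
  print(caps)  # e.g. "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: fp32: 14.20 TFLOPS, ..."

  # Chips missing from CHIP_FLOPS fall back to zeroed FLOPS rather than raising:
  flops = CHIP_FLOPS.get("Hypothetical Chip X", DeviceFlops(fp32=0, fp16=0, int8=0))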

+ 23 - 20
exo/topology/partitioning_strategy.py

@@ -4,34 +4,37 @@ from dataclasses import dataclass
 from .topology import Topology
 from exo.inference.shard import Shard
 
+
 # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
 @dataclass
 class Partition:
-    node_id: str
-    start: float
-    end: float
+  node_id: str
+  start: float
+  end: float
+
 
 class PartitioningStrategy(ABC):
-    @abstractmethod
-    def partition(self, topology: Topology) -> List[Partition]:
-        pass
+  @abstractmethod
+  def partition(self, topology: Topology) -> List[Partition]:
+    pass
+
 
 def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
-    shards = []
-    for i, partition in enumerate(partitions):
-        start_layer = int(partition.start * num_layers)
-        end_layer = int(partition.end * num_layers) - 1
+  shards = []
+  for i, partition in enumerate(partitions):
+    start_layer = int(partition.start * num_layers)
+    end_layer = int(partition.end * num_layers) - 1
 
-        # Ensure the last partition covers up to num_layers - 1
-        if i == len(partitions) - 1:
-            end_layer = num_layers - 1
+    # Ensure the last partition covers up to num_layers - 1
+    if i == len(partitions) - 1:
+      end_layer = num_layers - 1
 
-        # Ensure no empty shards
-        if start_layer <= end_layer:
-            shards.append(Shard(model_id, start_layer, end_layer, num_layers))
+    # Ensure no empty shards
+    if start_layer <= end_layer:
+      shards.append(Shard(model_id, start_layer, end_layer, num_layers))
 
-    # Ensure full coverage
-    if shards and shards[-1].end_layer < num_layers - 1:
-        shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
+  # Ensure full coverage
+  if shards and shards[-1].end_layer < num_layers - 1:
+    shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
 
-    return shards
+  return shards
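
As a quick illustration of the layer arithmetic above (a sketch, not part of the diff): two equal partitions over a 32-layer model map to layers 0-15 and 16-31, with the final shard clamped so the last layer is always covered.

  from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards

  partitions = [Partition("node1", 0.0, 0.5), Partition("node2", 0.5, 1.0)]
  shards = map_partitions_to_shards(partitions, 32, "model")
  # -> [Shard("model", 0, 15, 32), Shard("model", 16, 31, 32)]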

+ 12 - 11
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -4,15 +4,16 @@ from exo.inference.shard import Shard
 from .topology import Topology
 from .partitioning_strategy import Partition
 
+
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
-    def partition(self, topology: Topology) -> List[Partition]:
-        nodes = list(topology.all_nodes())
-        nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
-        total_memory = sum(node[1].memory for node in nodes)
-        partitions = []
-        start = 0
-        for node in nodes:
-            end = round(start + (node[1].memory / total_memory), 5)
-            partitions.append(Partition(node[0], start, end))
-            start = end
-        return partitions
+  def partition(self, topology: Topology) -> List[Partition]:
+    nodes = list(topology.all_nodes())
+    nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
+    total_memory = sum(node[1].memory for node in nodes)
+    partitions = []
+    start = 0
+    for node in nodes:
+      end = round(start + (node[1].memory / total_memory), 5)
+      partitions.append(Partition(node[0], start, end))
+      start = end
+    return partitions
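
The weights are plain memory shares; a quick sanity check (illustrative numbers matching the test further down): three nodes with 6000, 3000 and 1000 MB of memory sort in that order and produce boundaries at 0.6 and 0.9.

  total = 6000 + 3000 + 1000
  print(6000 / total, 3000 / total, 1000 / total)  # 0.6 0.3 0.1 -> partitions [0, 0.6), [0.6, 0.9), [0.9, 1.0)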

+ 72 - 64
exo/topology/test_device_capabilities.py

@@ -2,82 +2,90 @@ import unittest
 from unittest.mock import patch
 from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
 
+
 class TestMacDeviceCapabilities(unittest.TestCase):
-    @patch('subprocess.check_output')
-    def test_mac_device_capabilities(self, mock_check_output):
-        # Mock the subprocess output
-        mock_check_output.return_value = b"""
+  @patch("subprocess.check_output")
+  def test_mac_device_capabilities(self, mock_check_output):
+    # Mock the subprocess output
+    mock_check_output.return_value = b"""
 Hardware:
 
-    Hardware Overview:
+Hardware Overview:
 
-        Model Name: MacBook Pro
-        Model Identifier: Mac15,9
-        Model Number: Z1CM000EFB/A
-        Chip: Apple M3 Max
-        Total Number of Cores: 16 (12 performance and 4 efficiency)
-        Memory: 128 GB
-        System Firmware Version: 10000.000.0
-        OS Loader Version: 10000.000.0
-        Serial Number (system): XXXXXXXXXX
-        Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
-        Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
-        Activation Lock Status: Enabled
-        """
+Model Name: MacBook Pro
+Model Identifier: Mac15,9
+Model Number: Z1CM000EFB/A
+Chip: Apple M3 Max
+Total Number of Cores: 16 (12 performance and 4 efficiency)
+Memory: 128 GB
+System Firmware Version: 10000.000.0
+OS Loader Version: 10000.000.0
+Serial Number (system): XXXXXXXXXX
+Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
+Activation Lock Status: Enabled
+"""
 
-        # Call the function
-        result = mac_device_capabilities()
+    # Call the function
+    result = mac_device_capabilities()
 
-        # Check the results
-        self.assertIsInstance(result, DeviceCapabilities)
-        self.assertEqual(result.model, "MacBook Pro")
-        self.assertEqual(result.chip, "Apple M3 Max")
-        self.assertEqual(result.memory, 131072)  # 16 GB in MB
-        self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS")
+    # Check the results
+    self.assertIsInstance(result, DeviceCapabilities)
+    self.assertEqual(result.model, "MacBook Pro")
+    self.assertEqual(result.chip, "Apple M3 Max")
+    self.assertEqual(result.memory, 131072)  # 128 GB in MB
+    self.assertEqual(
+      str(result),
+      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
+    )
 
-    @patch('subprocess.check_output')
-    def test_mac_device_capabilities(self, mock_check_output):
-        # Mock the subprocess output
-        mock_check_output.return_value = b"""
+  @patch("subprocess.check_output")
+  def test_mac_device_capabilities(self, mock_check_output):
+    # Mock the subprocess output
+    mock_check_output.return_value = b"""
 Hardware:
 
-    Hardware Overview:
+Hardware Overview:
+
+Model Name: MacBook Air
+Model Identifier: Mac14,2
+Model Number: MLY33B/A
+Chip: Apple M2
+Total Number of Cores: 8 (4 performance and 4 efficiency)
+Memory: 8 GB
+System Firmware Version: 10000.00.0
+OS Loader Version: 10000.00.0
+Serial Number (system): XXXXXXXXXX
+Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
+Activation Lock Status: Disabled
+"""
 
-      Model Name: MacBook Air
-      Model Identifier: Mac14,2
-      Model Number: MLY33B/A
-      Chip: Apple M2
-      Total Number of Cores: 8 (4 performance and 4 efficiency)
-      Memory: 8 GB
-      System Firmware Version: 10000.00.0
-      OS Loader Version: 10000.00.0
-      Serial Number (system): XXXXXXXXXX
-      Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
-      Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
-      Activation Lock Status: Disabled
-        """
+    # Call the function
+    result = mac_device_capabilities()
 
-        # Call the function
-        result = mac_device_capabilities()
+    # Check the results
+    self.assertIsInstance(result, DeviceCapabilities)
+    self.assertEqual(result.model, "MacBook Air")
+    self.assertEqual(result.chip, "Apple M2")
+    self.assertEqual(result.memory, 8192)  # 8 GB in MB
 
-        # Check the results
-        self.assertIsInstance(result, DeviceCapabilities)
-        self.assertEqual(result.model, "MacBook Air")
-        self.assertEqual(result.chip, "Apple M2")
-        self.assertEqual(result.memory, 8192)  # 8 GB in MB
+  @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
+  def test_mac_device_capabilities_real(self):
+    # Call the function without mocking
+    result = mac_device_capabilities()
 
-    @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
-    def test_mac_device_capabilities_real(self):
-        # Call the function without mocking
-        result = mac_device_capabilities()
+    # Check the results
+    self.assertIsInstance(result, DeviceCapabilities)
+    self.assertEqual(result.model, "MacBook Pro")
+    self.assertEqual(result.chip, "Apple M3 Max")
+    self.assertEqual(result.memory, 131072)  # 128 GB in MB
+    self.assertEqual(result.flops, DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS))
+    self.assertEqual(
+      str(result),
+      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
+    )
 
-        # Check the results
-        self.assertIsInstance(result, DeviceCapabilities)
-        self.assertEqual(result.model, "MacBook Pro")
-        self.assertEqual(result.chip, "Apple M3 Max")
-        self.assertEqual(result.memory, 131072)  # 128 GB in MB
-        self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
-        self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS")
 
-if __name__ == '__main__':
-    unittest.main()
+if __name__ == "__main__":
+  unittest.main()
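
For reference on the memory assertions above: mac_device_capabilities parses the "Memory:" line and multiplies a GB value by 1024, so "128 GB" becomes 131072 MB and "8 GB" becomes 8192 MB. A sketch of that conversion:

  memory_str = "128 GB"
  value, unit = memory_str.split()
  memory_mb = int(value) * 1024 if unit == "GB" else int(value)  # 131072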

+ 68 - 55
exo/topology/test_map_partitions.py

@@ -3,66 +3,79 @@ from typing import List
 from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards
 from exo.inference.shard import Shard
 
+
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
-    def test_map_partitions_to_shards(self):
-        partitions = [
-            Partition('node1', 0.0, 0.42857),
-            Partition('node2', 0.42857, 0.71428),
-            Partition('node3', 0.71428, 0.99999),
-        ]
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 12, 32),
-            Shard('model', 13, 21, 32),
-            Shard('model', 22, 31, 32),
-        ])
+  def test_map_partitions_to_shards(self):
+    partitions = [
+      Partition("node1", 0.0, 0.42857),
+      Partition("node2", 0.42857, 0.71428),
+      Partition("node3", 0.71428, 0.99999),
+    ]
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 12, 32),
+        Shard("model", 13, 21, 32),
+        Shard("model", 22, 31, 32),
+      ],
+    )
 
-        partitions = [
-            Partition('node1', 0.0, 0.1),
-            Partition('node2', 0.1, 0.2),
-            Partition('node3', 0.2, 1.0),
-        ]
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 2, 32),
-            Shard('model', 3, 5, 32),
-            Shard('model', 6, 31, 32),
-        ])
+    partitions = [
+      Partition("node1", 0.0, 0.1),
+      Partition("node2", 0.1, 0.2),
+      Partition("node3", 0.2, 1.0),
+    ]
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 2, 32),
+        Shard("model", 3, 5, 32),
+        Shard("model", 6, 31, 32),
+      ],
+    )
 
-        partitions = [
-            Partition('node1', 0.0, 1.0),
-        ]
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 31, 32),
-        ])
+    partitions = [
+      Partition("node1", 0.0, 1.0),
+    ]
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 31, 32),
+      ],
+    )
 
-        partitions = []
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [])
+    partitions = []
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(shards, [])
 
-    def test_broken_map_partitions_to_shards(self):
-        # this was an old broken implementation that sometimes had rounding errors!
-        def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
-            shards = []
-            for i, partition in enumerate(partitions):
-                start_layer = int(partition.start * num_layers)
-                end_layer = int(partition.end * num_layers) - 1
-                shards.append(Shard(model_id, start_layer, end_layer, num_layers))
-            return shards
+  def test_broken_map_partitions_to_shards(self):
+    # this was an old broken implementation that sometimes had rounding errors!
+    def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
+      shards = []
+      for i, partition in enumerate(partitions):
+        start_layer = int(partition.start * num_layers)
+        end_layer = int(partition.end * num_layers) - 1
+        shards.append(Shard(model_id, start_layer, end_layer, num_layers))
+      return shards
 
-        partitions = [
-            Partition('node1', 0.0, 0.42857),
-            Partition('node2', 0.42857, 0.71428),
-            Partition('node3', 0.71428, 0.99999),
-        ]
-        shards = _broken_map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 12, 32),
-            Shard('model', 13, 21, 32),
-            Shard('model', 22, 30, 32),
-        ])
+    partitions = [
+      Partition("node1", 0.0, 0.42857),
+      Partition("node2", 0.42857, 0.71428),
+      Partition("node3", 0.71428, 0.99999),
+    ]
+    shards = _broken_map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 12, 32),
+        Shard("model", 13, 21, 32),
+        Shard("model", 22, 30, 32),
+      ],
+    )
 
-if __name__ == '__main__':
-    unittest.main()
 
+if __name__ == "__main__":
+  unittest.main()
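
The "broken" expectation above comes straight from integer truncation (a quick check, not part of the diff): int(0.99999 * 32) - 1 == 30, so the naive mapping drops layer 31, which is exactly what the final-partition clamp in map_partitions_to_shards guards against.

  num_layers = 32
  print(int(0.99999 * num_layers) - 1)  # 30 -> the last layer is lost without the clamp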

+ 83 - 42
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -4,46 +4,87 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
 
+
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
-    def test_partition(self):
-        # triangle
-        # node1 -> node2 -> node3 -> node1
-        topology = Topology()
-        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.add_edge('node1', 'node2')
-        topology.add_edge('node2', 'node3')
-        topology.add_edge('node3', 'node1')
-        topology.add_edge('node1', 'node3')
-
-        strategy = RingMemoryWeightedPartitioningStrategy()
-        partitions = strategy.partition(topology)
-
-        self.assertEqual(len(partitions), 3)
-        self.assertEqual(partitions, [
-            Partition('node3', 0.0, 0.6),
-            Partition('node1', 0.6, 0.9),
-            Partition('node2', 0.9, 1.0),
-        ])
-
-    def test_partition_rounding(self):
-        # triangle
-        # node1 -> node2 -> node3 -> node1
-        topology = Topology()
-        topology.update_node('node1', DeviceCapabilities(model="MacBook Pro", chip="test1", memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node2', DeviceCapabilities(model="Mac Studio", chip="test2", memory=192*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node3', DeviceCapabilities(model="MacBook Pro", chip="test3", memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-
-        strategy = RingMemoryWeightedPartitioningStrategy()
-        partitions = strategy.partition(topology)
-
-        self.assertEqual(len(partitions), 3)
-        self.assertEqual(partitions, [
-            Partition('node3', 0.0, 0.42857),
-            Partition('node1', 0.6, 0.9),
-            Partition('node2', 0.9, 1.0),
-        ])
-
-if __name__ == '__main__':
-    unittest.main()
+  def test_partition(self):
+    # triangle
+    # node1 -> node2 -> node3 -> node1
+    topology = Topology()
+    topology.update_node(
+      "node1",
+      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+    topology.update_node(
+      "node2",
+      DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+    topology.update_node(
+      "node3",
+      DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+    topology.add_edge("node1", "node2")
+    topology.add_edge("node2", "node3")
+    topology.add_edge("node3", "node1")
+    topology.add_edge("node1", "node3")
+
+    strategy = RingMemoryWeightedPartitioningStrategy()
+    partitions = strategy.partition(topology)
+
+    self.assertEqual(len(partitions), 3)
+    self.assertEqual(
+      partitions,
+      [
+        Partition("node3", 0.0, 0.6),
+        Partition("node1", 0.6, 0.9),
+        Partition("node2", 0.9, 1.0),
+      ],
+    )
+
+  def test_partition_rounding(self):
+    # triangle
+    # node1 -> node2 -> node3 -> node1
+    topology = Topology()
+    topology.update_node(
+      "node1",
+      DeviceCapabilities(
+        model="MacBook Pro",
+        chip="test1",
+        memory=128 * 1024 * 1024 * 1024,
+        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      ),
+    )
+    topology.update_node(
+      "node2",
+      DeviceCapabilities(
+        model="Mac Studio",
+        chip="test2",
+        memory=192 * 1024 * 1024 * 1024,
+        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      ),
+    )
+    topology.update_node(
+      "node3",
+      DeviceCapabilities(
+        model="MacBook Pro",
+        chip="test3",
+        memory=128 * 1024 * 1024 * 1024,
+        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      ),
+    )
+
+    strategy = RingMemoryWeightedPartitioningStrategy()
+    partitions = strategy.partition(topology)
+
+    self.assertEqual(len(partitions), 3)
+    self.assertEqual(
+      partitions,
+      [
+        Partition("node3", 0.0, 0.42857),
+        Partition("node1", 0.6, 0.9),
+        Partition("node2", 0.9, 1.0),
+      ],
+    )
+
+
+if __name__ == "__main__":
+  unittest.main()

+ 45 - 44
exo/topology/topology.py

@@ -1,48 +1,49 @@
 from .device_capabilities import DeviceCapabilities
 from typing import Dict, Set, Optional
 
+
 class Topology:
-    def __init__(self):
-        self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
-        self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
-        self.active_node_id: Optional[str] = None
-
-    def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
-        self.nodes[node_id] = device_capabilities
-
-    def get_node(self, node_id: str) -> DeviceCapabilities:
-        return self.nodes.get(node_id)
-
-    def all_nodes(self):
-        return self.nodes.items()
-
-    def add_edge(self, node1_id: str, node2_id: str):
-        if node1_id not in self.peer_graph:
-            self.peer_graph[node1_id] = set()
-        if node2_id not in self.peer_graph:
-            self.peer_graph[node2_id] = set()
-        self.peer_graph[node1_id].add(node2_id)
-        self.peer_graph[node2_id].add(node1_id)
-
-    def get_neighbors(self, node_id: str) -> Set[str]:
-        return self.peer_graph.get(node_id, set())
-
-    def all_edges(self):
-        edges = []
-        for node, neighbors in self.peer_graph.items():
-            for neighbor in neighbors:
-                if (neighbor, node) not in edges:  # Avoid duplicate edges
-                    edges.append((node, neighbor))
-        return edges
-
-    def merge(self, other: 'Topology'):
-        for node_id, capabilities in other.nodes.items():
-            self.update_node(node_id, capabilities)
-        for node_id, neighbors in other.peer_graph.items():
-            for neighbor in neighbors:
-                self.add_edge(node_id, neighbor)
-
-    def __str__(self):
-        nodes_str = ', '.join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
-        edges_str = ', '.join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
-        return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
+  def __init__(self):
+    self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
+    self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+    self.active_node_id: Optional[str] = None
+
+  def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
+    self.nodes[node_id] = device_capabilities
+
+  def get_node(self, node_id: str) -> DeviceCapabilities:
+    return self.nodes.get(node_id)
+
+  def all_nodes(self):
+    return self.nodes.items()
+
+  def add_edge(self, node1_id: str, node2_id: str):
+    if node1_id not in self.peer_graph:
+      self.peer_graph[node1_id] = set()
+    if node2_id not in self.peer_graph:
+      self.peer_graph[node2_id] = set()
+    self.peer_graph[node1_id].add(node2_id)
+    self.peer_graph[node2_id].add(node1_id)
+
+  def get_neighbors(self, node_id: str) -> Set[str]:
+    return self.peer_graph.get(node_id, set())
+
+  def all_edges(self):
+    edges = []
+    for node, neighbors in self.peer_graph.items():
+      for neighbor in neighbors:
+        if (neighbor, node) not in edges:  # Avoid duplicate edges
+          edges.append((node, neighbor))
+    return edges
+
+  def merge(self, other: "Topology"):
+    for node_id, capabilities in other.nodes.items():
+      self.update_node(node_id, capabilities)
+    for node_id, neighbors in other.peer_graph.items():
+      for neighbor in neighbors:
+        self.add_edge(node_id, neighbor)
+
+  def __str__(self):
+    nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
+    edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+    return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
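
A short usage sketch of the class (hypothetical node IDs; capabilities reused from the constant defined earlier in this commit):

  from exo.topology.topology import Topology
  from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES

  t = Topology()
  t.update_node("a", UNKNOWN_DEVICE_CAPABILITIES)
  t.update_node("b", UNKNOWN_DEVICE_CAPABILITIES)
  t.add_edge("a", "b")         # undirected: both adjacency sets are updated
  print(t.get_neighbors("a"))  # {'b'}
  print(t.all_edges())         # [('a', 'b')] -- duplicates suppressed

  other = Topology()
  other.update_node("c", UNKNOWN_DEVICE_CAPABILITIES)
  other.add_edge("b", "c")
  t.merge(other)               # union of nodes and edges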

+ 57 - 29
exo/viz/test_topology_viz.py

@@ -5,38 +5,66 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
 
+
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        self.topology = Topology()
-        self.topology.update_node("node1", DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0,fp16=2.0,int8=4.0)))
-        self.topology.update_node("node2", DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0,fp16=4.0,int8=8.0)))
-        self.topology.update_node("node3", DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0,fp16=8.0,int8=16.0)))
-        self.topology.update_node("node4", DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0,fp16=16.0,int8=32.0)))
+  async def asyncSetUp(self):
+    self.topology = Topology()
+    self.topology.update_node(
+      "node1",
+      DeviceCapabilities(
+        model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)
+      ),
+    )
+    self.topology.update_node(
+      "node2",
+      DeviceCapabilities(
+        model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)
+      ),
+    )
+    self.topology.update_node(
+      "node3",
+      DeviceCapabilities(
+        model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)
+      ),
+    )
+    self.topology.update_node(
+      "node4",
+      DeviceCapabilities(
+        model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)
+      ),
+    )
+
+    self.top_viz = TopologyViz()
+    await asyncio.sleep(2)  # Simulate running for a short time
 
-        self.top_viz = TopologyViz()
-        await asyncio.sleep(2)  # Simulate running for a short time
+  async def test_layout_generation(self):
+    self.top_viz._generate_layout()
+    self.top_viz.refresh()
+    import time
 
-    async def test_layout_generation(self):
-        self.top_viz._generate_layout()
-        self.top_viz.refresh()
-        import time
-        time.sleep(2)
-        self.top_viz.update_visualization(self.topology, [
-            Partition("node1", 0, 0.2),
-            Partition("node4", 0.2, 0.4),
-            Partition("node2", 0.4, 0.8),
-            Partition("node3", 0.8, 0.9),
-        ])
-        time.sleep(2)
-        self.topology.active_node_id = "node3"
-        self.top_viz.update_visualization(self.topology, [
-            Partition("node1", 0, 0.3),
-            Partition("node5", 0.3, 0.5),
-            Partition("node2", 0.5, 0.7),
-            Partition("node4", 0.7, 0.9),
-        ])
-        time.sleep(2)
+    time.sleep(2)
+    self.top_viz.update_visualization(
+      self.topology,
+      [
+        Partition("node1", 0, 0.2),
+        Partition("node4", 0.2, 0.4),
+        Partition("node2", 0.4, 0.8),
+        Partition("node3", 0.8, 0.9),
+      ],
+    )
+    time.sleep(2)
+    self.topology.active_node_id = "node3"
+    self.top_viz.update_visualization(
+      self.topology,
+      [
+        Partition("node1", 0, 0.3),
+        Partition("node5", 0.3, 0.5),
+        Partition("node2", 0.5, 0.7),
+        Partition("node4", 0.7, 0.9),
+      ],
+    )
+    time.sleep(2)
 
 
 if __name__ == "__main__":
-    unittest.main()
+  unittest.main()

+ 160 - 154
exo/viz/topology_viz.py

@@ -11,158 +11,164 @@ from rich.style import Style
 from rich.color import Color
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
+
 class TopologyViz:
-    def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
-        self.chatgpt_api_endpoint = chatgpt_api_endpoint
-        self.web_chat_url = web_chat_url
-        self.topology = Topology()
-        self.partitions: List[Partition] = []
-
-        self.console = Console()
-        self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
-        self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
-        self.live_panel.start()
-
-    def update_visualization(self, topology: Topology, partitions: List[Partition]):
-        self.topology = topology
-        self.partitions = partitions
-        self.refresh()
-
-    def refresh(self):
-        self.panel.renderable = self._generate_layout()
-        # Update the panel title with the number of nodes and partitions
-        node_count = len(self.topology.nodes)
-        self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
-        self.live_panel.update(self.panel, refresh=True)
-
-    def _generate_layout(self) -> str:
-        # Calculate visualization parameters
-        num_partitions = len(self.partitions)
-        radius_x = 30  # Increased horizontal radius
-        radius_y = 12  # Decreased vertical radius
-        center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
-
-        # Generate visualization
-        visualization = [[' ' for _ in range(100)] for _ in range(55)]  # Decreased height
-
-        # Add exo_text at the top in bright yellow
-        exo_lines = exo_text.split('\n')
-        yellow_style = Style(color="bright_yellow")
-        max_line_length = max(len(line) for line in exo_lines)
-        for i, line in enumerate(exo_lines):
-            centered_line = line.center(max_line_length)
-            start_x = (100 - max_line_length) // 2 + 15 # Center the text plus empirical adjustment of 15
-            colored_line = Text(centered_line, style=yellow_style)
-            for j, char in enumerate(str(colored_line)):
-                if 0 <= start_x + j < 100 and i < len(visualization):
-                    visualization[i][start_x + j] = char
-
-        # Display chatgpt_api_endpoint and web_chat_url if set
-        info_lines = []
-        if self.web_chat_url:
-            info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
-        if self.chatgpt_api_endpoint:
-            info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
-
-        info_start_y = len(exo_lines) + 1
-        for i, line in enumerate(info_lines):
-            start_x = (100 - len(line)) // 2 + 15 # Center the info lines plus empirical adjustment of 15
-            for j, char in enumerate(line):
-                if 0 <= start_x + j < 100 and info_start_y + i < 55:
-                    visualization[info_start_y + i][start_x + j] = char
-
-        # Calculate total FLOPS and position on the bar
-        total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
-        bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
-
-        # Add GPU poor/rich bar
-        bar_width = 30  # Increased bar width
-        bar_start_x = (100 - bar_width) // 2  # Center the bar
-        bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
-        
-        # Create a gradient bar using emojis
-        gradient_bar = Text()
-        emojis = ['🟥', '🟧', '🟨', '🟩']  # Red, Orange, Yellow, Green
-        for i in range(bar_width):
-            emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
-            gradient_bar.append(emojis[emoji_index])
-
-        # Add the gradient bar to the visualization
-        visualization[bar_y][bar_start_x - 1] = '['
-        visualization[bar_y][bar_start_x + bar_width] = ']'
-        for i, segment in enumerate(str(gradient_bar)):
-            visualization[bar_y][bar_start_x + i] = segment
-        
-        # Add labels
-        visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = 'GPU poor'
-        visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = 'GPU rich'
-        
-        # Add position indicator and FLOPS value
-        pos_x = bar_start_x + int(bar_pos * bar_width)
-        flops_str = f"{total_flops:.2f} TFLOPS"
-        visualization[bar_y - 1][pos_x] = '▼'
-        visualization[bar_y + 1][pos_x - len(flops_str)//2:pos_x + len(flops_str)//2 + len(flops_str)%2] = flops_str
-        visualization[bar_y + 2][pos_x] = '▲'
-
-        for i, partition in enumerate(self.partitions):
-            device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
-
-            angle = 2 * math.pi * i / num_partitions
-            x = int(center_x + radius_x * math.cos(angle))
-            y = int(center_y + radius_y * math.sin(angle))
-
-            # Place node with different color for active node
-            if partition.node_id == self.topology.active_node_id:
-                visualization[y][x] = '🔴'  # Red circle for active node
-            else:
-                visualization[y][x] = '🔵'  # Blue circle for inactive nodes
-
-            # Place node info (model, memory, TFLOPS, partition) on three lines
-            node_info = [
-                f"{device_capabilities.model} {device_capabilities.memory // 1024}GB",
-                f"{device_capabilities.flops.fp16}TFLOPS",
-                f"[{partition.start:.2f}-{partition.end:.2f}]"
-            ]
-
-            # Calculate info position based on angle
-            info_distance_x = radius_x + 6  # Increased horizontal distance
-            info_distance_y = radius_y + 3  # Decreased vertical distance
-            info_x = int(center_x + info_distance_x * math.cos(angle))
-            info_y = int(center_y + info_distance_y * math.sin(angle))
-
-            # Adjust text position to avoid overwriting the node icon and prevent cutoff
-            if info_x < x:  # Text is to the left of the node
-                info_x = max(0, x - len(max(node_info, key=len)) - 1)
-            elif info_x > x:  # Text is to the right of the node
-                info_x = min(99 - len(max(node_info, key=len)), info_x)
-            
-            # Adjust for top and bottom nodes
-            if 5*math.pi/4 < angle < 7*math.pi/4:  # Node is near the top
-                info_x += 4  # Shift text slightly to the right
-            elif math.pi/4 < angle < 3*math.pi/4:  # Node is near the bottom
-                info_x += 3  # Shift text slightly to the right
-                info_y -= 2  # Move text up by two cells
-
-            for j, line in enumerate(node_info):
-                for k, char in enumerate(line):
-                    if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
-                        # Ensure we're not overwriting the node icon
-                        if info_y + j != y or info_x + k != x:
-                            visualization[info_y + j][info_x + k] = char
-
-            # Draw line to next node
-            next_i = (i + 1) % num_partitions
-            next_angle = 2 * math.pi * next_i / num_partitions
-            next_x = int(center_x + radius_x * math.cos(next_angle))
-            next_y = int(center_y + radius_y * math.sin(next_angle))
-
-            # Simple line drawing
-            steps = max(abs(next_x - x), abs(next_y - y))
-            for step in range(1, steps):
-                line_x = int(x + (next_x - x) * step / steps)
-                line_y = int(y + (next_y - y) * step / steps)
-                if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
-                    visualization[line_y][line_x] = '-'
-
-        # Convert to string
-        return '\n'.join(''.join(str(char) for char in row) for row in visualization)
+  def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
+    self.chatgpt_api_endpoint = chatgpt_api_endpoint
+    self.web_chat_url = web_chat_url
+    self.topology = Topology()
+    self.partitions: List[Partition] = []
+
+    self.console = Console()
+    self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+    self.live_panel.start()
+
+  def update_visualization(self, topology: Topology, partitions: List[Partition]):
+    self.topology = topology
+    self.partitions = partitions
+    self.refresh()
+
+  def refresh(self):
+    self.panel.renderable = self._generate_layout()
+    # Update the panel title with the number of nodes and partitions
+    node_count = len(self.topology.nodes)
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+    self.live_panel.update(self.panel, refresh=True)
+
+  def _generate_layout(self) -> str:
+    # Calculate visualization parameters
+    num_partitions = len(self.partitions)
+    radius_x = 30  # Increased horizontal radius
+    radius_y = 12  # Decreased vertical radius
+    center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
+
+    # Generate visualization
+    visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
+
+    # Add exo_text at the top in bright yellow
+    exo_lines = exo_text.split("\n")
+    yellow_style = Style(color="bright_yellow")
+    max_line_length = max(len(line) for line in exo_lines)
+    for i, line in enumerate(exo_lines):
+      centered_line = line.center(max_line_length)
+      start_x = (100 - max_line_length) // 2 + 15  # Center the text plus empirical adjustment of 15
+      colored_line = Text(centered_line, style=yellow_style)
+      for j, char in enumerate(str(colored_line)):
+        if 0 <= start_x + j < 100 and i < len(visualization):
+          visualization[i][start_x + j] = char
+
+    # Display chatgpt_api_endpoint and web_chat_url if set
+    info_lines = []
+    if self.web_chat_url:
+      info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
+    if self.chatgpt_api_endpoint:
+      info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
+
+    info_start_y = len(exo_lines) + 1
+    for i, line in enumerate(info_lines):
+      start_x = (100 - len(line)) // 2 + 15  # Center the info lines plus empirical adjustment of 15
+      for j, char in enumerate(line):
+        if 0 <= start_x + j < 100 and info_start_y + i < 55:
+          visualization[info_start_y + i][start_x + j] = char
+
+    # Calculate total FLOPS and position on the bar
+    total_flops = sum(
+      self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16
+      for partition in self.partitions
+    )
+    bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
+
+    # Add GPU poor/rich bar
+    bar_width = 30  # Increased bar width
+    bar_start_x = (100 - bar_width) // 2  # Center the bar
+    bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
+
+    # Create a gradient bar using emojis
+    gradient_bar = Text()
+    emojis = ["🟥", "🟧", "🟨", "🟩"]  # Red, Orange, Yellow, Green
+    for i in range(bar_width):
+      emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
+      gradient_bar.append(emojis[emoji_index])
+
+    # Add the gradient bar to the visualization
+    visualization[bar_y][bar_start_x - 1] = "["
+    visualization[bar_y][bar_start_x + bar_width] = "]"
+    for i, segment in enumerate(str(gradient_bar)):
+      visualization[bar_y][bar_start_x + i] = segment
+
+    # Add labels
+    visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
+    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
+
+    # Add position indicator and FLOPS value
+    pos_x = bar_start_x + int(bar_pos * bar_width)
+    flops_str = f"{total_flops:.2f} TFLOPS"
+    visualization[bar_y - 1][pos_x] = "▼"
+    visualization[bar_y + 1][
+      pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2
+    ] = flops_str
+    visualization[bar_y + 2][pos_x] = "▲"
+
+    for i, partition in enumerate(self.partitions):
+      device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
+
+      angle = 2 * math.pi * i / num_partitions
+      x = int(center_x + radius_x * math.cos(angle))
+      y = int(center_y + radius_y * math.sin(angle))
+
+      # Place node with different color for active node
+      if partition.node_id == self.topology.active_node_id:
+        visualization[y][x] = "🔴"  # Red circle for active node
+      else:
+        visualization[y][x] = "🔵"  # Blue circle for inactive nodes
+
+      # Place node info (model, memory, TFLOPS, partition) on three lines
+      node_info = [
+        f"{device_capabilities.model} {device_capabilities.memory // 1024}GB",
+        f"{device_capabilities.flops.fp16}TFLOPS",
+        f"[{partition.start:.2f}-{partition.end:.2f}]",
+      ]
+
+      # Calculate info position based on angle
+      info_distance_x = radius_x + 6  # Increased horizontal distance
+      info_distance_y = radius_y + 3  # Decreased vertical distance
+      info_x = int(center_x + info_distance_x * math.cos(angle))
+      info_y = int(center_y + info_distance_y * math.sin(angle))
+
+      # Adjust text position to avoid overwriting the node icon and prevent cutoff
+      if info_x < x:  # Text is to the left of the node
+        info_x = max(0, x - len(max(node_info, key=len)) - 1)
+      elif info_x > x:  # Text is to the right of the node
+        info_x = min(99 - len(max(node_info, key=len)), info_x)
+
+      # Adjust for top and bottom nodes
+      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:  # Node is near the top
+        info_x += 4  # Shift text slightly to the right
+      elif math.pi / 4 < angle < 3 * math.pi / 4:  # Node is near the bottom
+        info_x += 3  # Shift text slightly to the right
+        info_y -= 2  # Move text up by two cells
+
+      for j, line in enumerate(node_info):
+        for k, char in enumerate(line):
+          if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
+            # Ensure we're not overwriting the node icon
+            if info_y + j != y or info_x + k != x:
+              visualization[info_y + j][info_x + k] = char
+
+      # Draw line to next node
+      next_i = (i + 1) % num_partitions
+      next_angle = 2 * math.pi * next_i / num_partitions
+      next_x = int(center_x + radius_x * math.cos(next_angle))
+      next_y = int(center_y + radius_y * math.sin(next_angle))
+
+      # Simple line drawing
+      steps = max(abs(next_x - x), abs(next_y - y))
+      for step in range(1, steps):
+        line_x = int(x + (next_x - x) * step / steps)
+        line_y = int(y + (next_y - y) * step / steps)
+        if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
+          visualization[line_y][line_x] = "-"
+
+    # Convert to string
+    return "\n".join("".join(str(char) for char in row) for row in visualization)

+ 105 - 0
format.py

@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+import re
+import subprocess
+import sys
+import os
+import fnmatch
+
+DEBUG_PATTERN = re.compile(r'^(\s*)(if\s+DEBUG\s*>=?\s*\d+\s*:.+)$', re.MULTILINE)
+PLACEHOLDER = "###DEBUG_PLACEHOLDER###"
+
+# Add ignore patterns here
+IGNORE_PATTERNS = [
+    '.venv/*',
+    'setup.py',
+    '*helpers.py',
+    '*node_service_pb2.py',
+    '*node_service_pb2_grpc.py',
+]
+
+def should_ignore(file_path):
+    for pattern in IGNORE_PATTERNS:
+        if fnmatch.fnmatch(file_path, pattern):
+            return True
+    return False
+
+def preserve_debug_lines(content):
+    def replace(match):
+        indent, line = match.groups()
+        return f"{indent}{PLACEHOLDER}{line.strip()}"
+    return DEBUG_PATTERN.sub(replace, content)
+
+def restore_debug_lines(content):
+    return re.sub(f"^(\\s*){PLACEHOLDER}(.+)$", r"\1\2", content, flags=re.MULTILINE)
+
+def adjust_indentation(content):
+    lines = content.split('\n')
+    adjusted_lines = []
+    for line in lines:
+        if line.strip() and not line.startswith(PLACEHOLDER):
+            indent = len(line) - len(line.lstrip())
+            new_indent = ' ' * (indent // 2)
+            adjusted_lines.append(new_indent + line.lstrip())
+        else:
+            adjusted_lines.append(line)
+    return '\n'.join(adjusted_lines)
+
+def process_file(file_path, process_func):
+    with open(file_path, 'r') as file:
+        content = file.read()
+    
+    modified_content = process_func(content)
+    
+    if content != modified_content:
+        with open(file_path, 'w') as file:
+            file.write(modified_content)
+
+def run_black(target):
+    # Convert ignore patterns to Black's --extend-exclude format
+    exclude_patterns = '|'.join(f'({pattern.replace("*", ".*")})' for pattern in IGNORE_PATTERNS)
+    command = [
+        "black",
+        "--line-length", "120",
+        "--extend-exclude", exclude_patterns,
+        target
+    ]
+    subprocess.run(command, check=True)
+
+def format_files(target):
+    if os.path.isfile(target):
+        files = [target] if not should_ignore(target) else []
+    elif os.path.isdir(target):
+        files = []
+        for root, _, filenames in os.walk(target):
+            for filename in filenames:
+                if filename.endswith('.py'):
+                    file_path = os.path.join(root, filename)
+                    if not should_ignore(file_path):
+                        files.append(file_path)
+    else:
+        print(f"Error: {target} is not a valid file or directory")
+        return
+
+    # Preserve debug lines
+    for file in files:
+        process_file(file, preserve_debug_lines)
+
+    # Run Black
+    run_black(target)
+
+    # Adjust indentation and restore debug lines
+    for file in files:
+        process_file(file, adjust_indentation)
+        process_file(file, restore_debug_lines)
+
+def main():
+    if len(sys.argv) < 2:
+        print("Usage: python format.py <directory_or_file>")
+        sys.exit(1)
+
+    target = sys.argv[1]
+    format_files(target)
+    print("Formatting completed.")
+
+if __name__ == "__main__":
+    main()
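
The script above masks one-line `if DEBUG >= n:` statements before running Black so they survive reformatting, then unmasks them afterwards. A minimal round-trip sketch of that behaviour, assuming format.py is importable from the repo root; the sample source string is made up:

from format import preserve_debug_lines, restore_debug_lines

sample = "def step(x):\n  if DEBUG >= 2: print(x)\n  return x + 1\n"

masked = preserve_debug_lines(sample)    # DEBUG one-liner tagged with the placeholder
restored = restore_debug_lines(masked)   # placeholder stripped, original line restored

assert "###DEBUG_PLACEHOLDER###" in masked
assert restored == sample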

+ 5 - 0
lint.sh

@@ -0,0 +1,5 @@
+#!/bin/bash
+
+pip3 install -e '.[linting]'
+python3 -m ruff check .
+python3 -m pylint .

+ 17 - 0
pyproject.toml

@@ -0,0 +1,17 @@
+[tool.black]
+line-length = 120
+indent-size = 2
+skip-string-normalization = true
+
+[tool.isort]
+profile = "black"
+line_length = 120
+indent = "  "
+
+[tool.pylint.format]
+indent-string = '  '
+max-line-length = 120
+
+[tool.autopep8]
+max_line_length = 120
+indent_size = 2

+ 43 - 0
ruff.toml

@@ -0,0 +1,43 @@
+indent-width = 2
+preview = true
+target-version = "py312"
+
+lint.select = [
+  "F",  # Pyflakes
+  "W6",
+  "E71",
+  "E72",
+  "E112",   # no-indented-block
+  "E113",   # unexpected-indentation
+  # "E124",
+  "E203",   # whitespace-before-punctuation
+  "E272",   # multiple-spaces-before-keyword
+  "E303",   # too-many-blank-lines
+  "E304",   # blank-line-after-decorator
+  "E501",   # line-too-long
+  # "E502",
+  "E702",   # multiple-statements-on-one-line-semicolon
+  "E703",   # useless-semicolon
+  "E731",   # lambda-assignment
+  "W191",   # tab-indentation
+  "W291",   # trailing-whitespace
+  "W293",   # blank-line-with-whitespace
+  "UP039",  # unnecessary-class-parentheses
+  "C416",   # unnecessary-comprehension
+  "RET506", # superfluous-else-raise
+  "RET507", # superfluous-else-continue
+  "A",      # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
+  "SIM105", # suppressible-exception
+  "FURB110",# if-exp-instead-of-or-operator
+]
+
+line-length = 200
+
+exclude = [
+  "docs/",
+  "examples/",
+  "extra/",
+  "exo/networking/grpc/node_service_pb2.py",
+  "exo/networking/grpc/node_service_pb2_grpc.py",
+  "exo/helpers.py",
+]
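
Of the rules enabled above, E731 (lambda-assignment) is a representative one; a small illustrative pair, with made-up names, of code ruff would flag under this config and the rewrite that passes:

# Flagged by E731: binding a lambda to a name.
clamp = lambda v: max(0, min(1, v))

# Rewrite that passes the check.
def clamp(v):
  return max(0, min(1, v))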

+ 10 - 1
setup.py

@@ -1,6 +1,7 @@
-from setuptools import setup, find_packages
 import sys
 
+from setuptools import find_packages, setup
+
 # Base requirements for all platforms
 install_requires = [
     "aiohttp==3.9.5",
@@ -35,10 +36,18 @@ if sys.platform.startswith("darwin"):
         ]
     )
 
+extras_require = {
+    "linting": [
+        "pylint==3.2.6",
+        "ruff==0.5.5",
+        "mypy==1.11.0",
+    ]
+}
 
 setup(
     name="exo",
     version="0.0.1",
     packages=find_packages(),
     install_requires=install_requires,
+    extras_require=extras_require,
 )