
Merge branch 'main' into HEAD

Alex Cheema, 1 year ago
Parent
Commit
7d5eed1111
49 files changed, with 3585 insertions and 2476 deletions
  1. .pylintrc (+472 -0)
  2. exo/__init__.py (+1 -1)
  3. exo/api/__init__.py (+1 -1)
  4. exo/api/chatgpt_api.py (+276 -250)
  5. exo/helpers.py (+110 -102)
  6. exo/inference/debug_inference_engine.py (+44 -21)
  7. exo/inference/inference_engine.py (+7 -6)
  8. exo/inference/mlx/models/base.py (+9 -0)
  9. exo/inference/mlx/models/deepseek_v2.py (+130 -0)
  10. exo/inference/mlx/models/llama.py (+124 -0)
  11. exo/inference/mlx/models/sharded_llama.py (+0 -351)
  12. exo/inference/mlx/sharded_inference_engine.py (+17 -16)
  13. exo/inference/mlx/sharded_model.py (+52 -55)
  14. exo/inference/mlx/sharded_utils.py (+182 -198)
  15. exo/inference/mlx/test_sharded_llama.py (+5 -5)
  16. exo/inference/mlx/test_sharded_model.py (+29 -28)
  17. exo/inference/shard.py (+18 -17)
  18. exo/inference/test_inference_engine.py (+35 -12)
  19. exo/inference/tinygrad/inference.py (+180 -93)
  20. exo/inference/tinygrad/models/llama.py (+100 -36)
  21. exo/networking/__init__.py (+1 -1)
  22. exo/networking/discovery.py (+10 -9)
  23. exo/networking/grpc/grpc_discovery.py (+172 -120)
  24. exo/networking/grpc/grpc_peer_handle.py (+94 -81)
  25. exo/networking/grpc/grpc_server.py (+94 -65)
  26. exo/networking/grpc/test_grpc_discovery.py (+14 -13)
  27. exo/networking/peer_handle.py (+31 -30)
  28. exo/networking/server.py (+7 -6)
  29. exo/orchestration/node.py (+40 -39)
  30. exo/orchestration/standard_node.py (+369 -266)
  31. exo/orchestration/test_node.py (+49 -48)
  32. exo/stats/metrics.py (+19 -18)
  33. exo/test_callbacks.py (+33 -30)
  34. exo/topology/device_capabilities.py (+154 -132)
  35. exo/topology/partitioning_strategy.py (+24 -21)
  36. exo/topology/ring_memory_weighted_partitioning_strategy.py (+12 -12)
  37. exo/topology/test_device_capabilities.py (+72 -64)
  38. exo/topology/test_map_partitions.py (+68 -55)
  39. exo/topology/test_ring_memory_weighted_partitioning_strategy.py (+83 -42)
  40. exo/topology/topology.py (+45 -44)
  41. exo/viz/test_topology_viz.py (+49 -29)
  42. exo/viz/topology_viz.py (+155 -155)
  43. format.py (+110 -0)
  44. lint.sh (+5 -0)
  45. main.py (+12 -3)
  46. pyproject.toml (+17 -0)
  47. ruff.toml (+43 -0)
  48. setup.py (+10 -1)
  49. tinychat/examples/tinychat/index.html (+1 -0)

+ 472 - 0
.pylintrc

@@ -0,0 +1,472 @@
+[MASTER]
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code
+extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
+
+# Add files or directories to the blacklist. They should be base names, not
+# paths.
+ignore=CVS
+
+# Add files or directories matching the regex patterns to the blacklist. The
+# regex matches against base names, not paths.
+ignore-patterns=.*node_service_pb2.*
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint.
+jobs=4
+
+# List of plugins (as comma separated values of python modules names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Specify a configuration file.
+#rcfile=
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
+confidence=
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once).You can also use "--disable=all" to
+# disable everything first and then reenable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use"--disable=all --enable=classes
+# --disable=W"
+disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
+# E1101 for function binding
+# W0221 for Function class
+# W0105 for comment strings
+# E0401 for missing imports
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member,use-a-generator, no-else-return
+
+
+[REPORTS]
+
+# Python expression which should return a note less than 10 (10 is the highest
+# note). You have access to the variables errors warning, statement which
+# respectively contain the number of errors / warnings messages and the total
+# number of statements analyzed. This is used by the global evaluation report
+# (RP0004).
+evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details
+#msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio).You can also give a reporter class, eg
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages
+reports=no
+
+# Activate the evaluation score.
+score=yes
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=optparse.Values,sys.exit
+
+
+[LOGGING]
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format
+logging-modules=logging
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it working
+# install python-enchant package.
+spelling-dict=
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to indicated private dictionary in
+# --spelling-private-dict-file option instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+      XXX,
+      TODO
+
+
+[SIMILARITIES]
+
+# Ignore comments when computing similarities.
+ignore-comments=yes
+
+# Ignore docstrings when computing similarities.
+ignore-docstrings=yes
+
+# Ignore imports when computing similarities.
+ignore-imports=no
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
+
+# Tells whether missing members accessed in mixin class should be ignored. A
+# mixin class is detected if its name ends with "mixin" (case insensitive).
+ignore-mixin-members=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis. It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid to define new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+          _cb
+
+# A regular expression matching the name of dummy variables (i.e. expectedly
+# not used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored. Default to name
+# with leading underscore
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging  or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='  '
+
+# Maximum number of characters on a single line.
+max-line-length=150
+
+# Maximum number of lines in a module
+max-module-lines=1000
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[BASIC]
+
+# Naming style matching correct argument names
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style
+#argument-rgx=
+
+# Naming style matching correct attribute names
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma
+bad-names=foo,
+          bar,
+          baz,
+          toto,
+          tutu,
+          tata
+
+# Naming style matching correct class attribute names
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style
+#class-attribute-rgx=
+
+# Naming style matching correct class names
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-style
+#class-rgx=
+
+# Naming style matching correct constant names
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma
+good-names=i,
+           j,
+           k,
+           ex,
+           Run,
+           _
+
+# Include a hint for the correct naming format with invalid-name
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style
+#inlinevar-rgx=
+
+# Naming style matching correct method names
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style
+#method-rgx=
+
+# Naming style matching correct module names
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+property-classes=abc.abstractproperty
+
+# Naming style matching correct variable names
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style
+#variable-rgx=
+
+
+[DESIGN]
+
+# Maximum number of arguments for function / method
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in a if statement
+max-bool-expr=5
+
+# Maximum number of branch for function / method body
+max-branches=12
+
+# Maximum number of locals for function / method body
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body
+max-returns=6
+
+# Maximum number of statements in function / method body
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[CLASSES]
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+                      __new__,
+                      setUp
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+                  _fields,
+                  _replace,
+                  _source,
+                  _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[IMPORTS]
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Deprecated modules which should not be used, separated by a comma
+deprecated-modules=regsub,
+                   TERMIOS,
+                   Bastion,
+                   rexec
+
+# Create a graph of external dependencies in the given file (report RP0402 must
+# not be disabled)
+ext-import-graph=
+
+# Create a graph of every (i.e. internal and external) dependencies in the
+# given file (report RP0402 must not be disabled)
+import-graph=
+
+# Create a graph of internal dependencies in the given file (report RP0402 must
+# not be disabled)
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+[STRING]
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=yes
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when being caught. Defaults to
+# "Exception"
+overgeneral-exceptions=builtins.Exception

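The .pylintrc above disables whole message categories (disable=C,R,W...), re-enables only c-extension-no-member, use-a-generator and no-else-return, and sets indent-string to two spaces with max-line-length=150 to match the reformatted code below. A hedged sketch of how such an rcfile could be driven from Python (illustrative only; the lint.sh added in this commit is not shown here):

# Illustrative lint runner, not the repo's lint.sh: pylint reads the options above via --rcfile.
import subprocess
import sys

result = subprocess.run(
  [sys.executable, "-m", "pylint", "--rcfile=.pylintrc", "exo/"],
  check=False,  # pylint exits non-zero whenever it reports any message
)
sys.exit(result.returncode)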
+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG, DEBUG_DISCOVERY, VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 1 - 1
exo/api/__init__.py

@@ -1 +1 @@
-from exo.api.chatgpt_api import ChatGPTAPI
+from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI

+ 276 - 250
exo/api/chatgpt_api.py

@@ -13,279 +13,305 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node
 
 shard_mappings = {
-    ### llama
-    "llama-3.1-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    },
-    "llama-3.1-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    },
-    "llama-3.1-405b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
-    },
-    "llama-3-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
-    },
-    "llama-3-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
-    },
-    ### mistral
-    "mistral-nemo": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
-    },
-    "mistral-large": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
-    },
+  ### llama
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+  },
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+  },
+  ### mistral
+  "mistral-nemo": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+  },
+  "mistral-large": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+  },
+  ### deepseek v2
+  "deepseek-coder-v2-lite": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),
+  },
 }
 
+
 class Message:
-    def __init__(self, role: str, content: str):
-        self.role = role
-        self.content = content
+  def __init__(self, role: str, content: str):
+    self.role = role
+    self.content = content
+
 
 class ChatCompletionRequest:
-    def __init__(self, model: str, messages: List[Message], temperature: float):
-        self.model = model
-        self.messages = messages
-        self.temperature = temperature
+  def __init__(self, model: str, messages: List[Message], temperature: float):
+    self.model = model
+    self.messages = messages
+    self.temperature = temperature
+
 
 def resolve_tinygrad_tokenizer(model_id: str):
-    if model_id == "llama3-8b-sfr":
-        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-    elif model_id == "llama3-70b-sfr":
-        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
-    else:
-        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
+  if model_id == "llama3-8b-sfr":
+    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+  elif model_id == "llama3-70b-sfr":
+    return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+  else:
+    raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
+
 
 async def resolve_tokenizer(model_id: str):
-    try:
-        if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
-        return AutoTokenizer.from_pretrained(model_id)
-    except:
-        import traceback
-        if DEBUG >= 2: print(traceback.format_exc())
-        if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
+  try:
+    if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+    return AutoTokenizer.from_pretrained(model_id)
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer. Error: {e}")
+    import traceback
 
-    try:
-        if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
-        return resolve_tinygrad_tokenizer(model_id)
-    except:
-        import traceback
-        if DEBUG >= 2: print(traceback.format_exc())
-        if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
+    if DEBUG >= 2: print(traceback.format_exc())
+
+  try:
+    if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
+    return resolve_tinygrad_tokenizer(model_id)
+  except Exception as e:
+    if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer. Error: {e}")
+    import traceback
+
+    if DEBUG >= 2: print(traceback.format_exc())
+
+  if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
+  from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+
+  return load_tokenizer(await get_model_path(model_id))
 
-    if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
-    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-    return load_tokenizer(await get_model_path(model_id))
 
 def generate_completion(
-        chat_request: ChatCompletionRequest,
-        tokenizer,
-        prompt: str,
-        request_id: str,
-        tokens: List[int],
-        stream: bool,
-        finish_reason: Union[Literal["length", "stop"], None],
-        object_type: Literal["chat.completion", "text_completion"]
-    ) -> dict:
-    completion = {
-        "id": f"chatcmpl-{request_id}",
-        "object": object_type,
-        "created": int(time.time()),
-        "model": chat_request.model,
-        "system_fingerprint": f"exo_{VERSION}",
-        "choices": [
-            {
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": tokenizer.decode(tokens)
-                },
-                "logprobs": None,
-                "finish_reason": finish_reason,
-            }
-        ]
+  chat_request: ChatCompletionRequest,
+  tokenizer,
+  prompt: str,
+  request_id: str,
+  tokens: List[int],
+  stream: bool,
+  finish_reason: Union[Literal["length", "stop"], None],
+  object_type: Literal["chat.completion", "text_completion"],
+) -> dict:
+  completion = {
+    "id": f"chatcmpl-{request_id}",
+    "object": object_type,
+    "created": int(time.time()),
+    "model": chat_request.model,
+    "system_fingerprint": f"exo_{VERSION}",
+    "choices": [
+      {
+        "index": 0,
+        "message": {"role": "assistant", "content": tokenizer.decode(tokens)},
+        "logprobs": None,
+        "finish_reason": finish_reason,
+      }
+    ],
+  }
+
+  if not stream:
+    completion["usage"] = {
+      "prompt_tokens": len(tokenizer.encode(prompt)),
+      "completion_tokens": len(tokens),
+      "total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
     }
 
-    if not stream:
-        completion["usage"] = {
-            "prompt_tokens": len(tokenizer.encode(prompt)),
-            "completion_tokens": len(tokens),
-            "total_tokens": len(tokenizer.encode(prompt)) + len(tokens)
-        }
+  choice = completion["choices"][0]
+  if object_type.startswith("chat.completion"):
+    key_name = "delta" if stream else "message"
+    choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
+  elif object_type == "text_completion":
+    choice["text"] = tokenizer.decode(tokens)
+  else:
+    ValueError(f"Unsupported response type: {object_type}")
 
-    choice = completion["choices"][0]
-    if object_type.startswith("chat.completion"):
-        key_name = "delta" if stream else "message"
-        choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
-    elif object_type == "text_completion":
-        choice['text'] = tokenizer.decode(tokens)
-    else:
-        ValueError(f"Unsupported response type: {object_type}")
+  return completion
 
-    return completion
 
 def build_prompt(tokenizer, messages: List[Message]):
-    return tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
+  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
 
 def parse_message(data: dict):
-    if 'role' not in data or 'content' not in data:
-        raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-    return Message(data['role'], data['content'])
+  if "role" not in data or "content" not in data:
+    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+  return Message(data["role"], data["content"])
+
 
 def parse_chat_request(data: dict):
-    return ChatCompletionRequest(
-        data.get('model', 'llama-3.1-8b'),
-        [parse_message(msg) for msg in data['messages']],
-        data.get('temperature', 0.0)
-    )
+  return ChatCompletionRequest(
+    data.get("model", "llama-3.1-8b"),
+    [parse_message(msg) for msg in data["messages"]],
+    data.get("temperature", 0.0),
+  )
+
 
 class ChatGPTAPI:
-    def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
-        self.node = node
-        self.inference_engine_classname = inference_engine_classname
-        self.response_timeout_secs = response_timeout_secs
-        self.app = web.Application()
-        self.prev_token_lens: Dict[str, int] = {}
-        self.stream_tasks: Dict[str, asyncio.Task] = {}
-        cors = aiohttp_cors.setup(self.app)
-        cors_options = aiohttp_cors.ResourceOptions(
-            allow_credentials=True,
-            expose_headers="*",
-            allow_headers="*",
-            allow_methods="*",
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
+    self.node = node
+    self.inference_engine_classname = inference_engine_classname
+    self.response_timeout_secs = response_timeout_secs
+    self.app = web.Application()
+    self.prev_token_lens: Dict[str, int] = {}
+    self.stream_tasks: Dict[str, asyncio.Task] = {}
+    cors = aiohttp_cors.setup(self.app)
+    cors_options = aiohttp_cors.ResourceOptions(
+      allow_credentials=True,
+      expose_headers="*",
+      allow_headers="*",
+      allow_methods="*",
+    )
+    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
+    self.static_dir = Path(__file__).parent.parent.parent / "tinychat/examples/tinychat"
+    self.app.router.add_get("/", self.handle_root)
+    self.app.router.add_static("/", self.static_dir, name="static")
+
+    # Add middleware to log every request
+    self.app.middlewares.append(self.log_request)
+
+  async def log_request(self, app, handler):
+    async def middleware(request):
+      if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
+      return await handler(request)
+
+    return middleware
+
+  async def handle_root(self, request):
+    print(f"Handling root request from {request.remote}")
+    return web.FileResponse(self.static_dir / "index.html")
+
+  async def handle_post_chat_token_encode(self, request):
+    data = await request.json()
+    shard = shard_mappings.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    messages = [parse_message(msg) for msg in data.get("messages", [])]
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    return web.json_response({"length": len(build_prompt(tokenizer, messages))})
+
+  async def handle_post_chat_completions(self, request):
+    data = await request.json()
+    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    stream = data.get("stream", False)
+    chat_request = parse_chat_request(data)
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
+      chat_request.model = "llama-3.1-8b"
+    if not chat_request.model or chat_request.model not in shard_mappings:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+      chat_request.model = "llama-3.1-8b"
+    shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
+    if not shard:
+      supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+      return web.json_response(
+        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
+        status=400,
+      )
+    request_id = str(uuid.uuid4())
+
+    tokenizer = await resolve_tokenizer(shard.model_id)
+    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+
+    prompt = build_prompt(tokenizer, chat_request.messages)
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+    try:
+      await self.node.process_prompt(shard, prompt, request_id=request_id)
+    except Exception as e:
+      if DEBUG >= 2:
+        import traceback
+
+        traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
+    try:
+      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
+
+      if stream:
+        response = web.StreamResponse(
+          status=200,
+          reason="OK",
+          headers={
+            "Content-Type": "application/json",
+            "Cache-Control": "no-cache",
+          },
+        )
+        await response.prepare(request)
+
+        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
+          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+          new_tokens = tokens[prev_last_tokens_len:]
+          finish_reason = None
+          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+          if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
+            new_tokens = new_tokens[:-1]
+            if is_finished:
+              finish_reason = "stop"
+          if is_finished and not finish_reason:
+            finish_reason = "length"
+
+          completion = generate_completion(
+            chat_request,
+            tokenizer,
+            prompt,
+            request_id,
+            new_tokens,
+            stream,
+            finish_reason,
+            "chat.completion",
+          )
+          if DEBUG >= 2: print(f"Streaming completion: {completion}")
+          await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+
+        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
+          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+
+          return _request_id == request_id and is_finished
+
+        _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
+        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
+          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
+          try:
+            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
+          except asyncio.TimeoutError:
+            print("WARNING: Stream task timed out. This should not happen.")
+        await response.write_eof()
+        return response
+      else:
+        _, tokens, _ = await callback.wait(
+          lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
+          timeout=self.response_timeout_secs,
         )
-        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post_chat_completions), {
-            "*": cors_options
-        })
-        cors.add(self.app.router.add_post('/v1/chat/token/encode', self.handle_post_chat_token_encode), {
-            "*": cors_options
-        })
-        self.static_dir = Path(__file__).parent.parent.parent / 'tinychat/examples/tinychat'
-        self.app.router.add_get('/', self.handle_root)
-        self.app.router.add_static('/', self.static_dir, name='static')
-
-        # Add middleware to log every request
-        self.app.middlewares.append(self.log_request)
-
-    async def log_request(self, app, handler):
-        async def middleware(request):
-            if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
-            return await handler(request)
-        return middleware
-
-    async def handle_root(self, request):
-        print(f"Handling root request from {request.remote}")
-        return web.FileResponse(self.static_dir / 'index.html')
-
-    async def handle_post_chat_token_encode(self, request):
-        data = await request.json()
-        shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
-        messages = [parse_message(msg) for msg in data.get('messages', [])]
-        tokenizer = await resolve_tokenizer(shard.model_id)
-        return web.json_response({'length': len(build_prompt(tokenizer, messages))})
-
-    async def handle_post_chat_completions(self, request):
-        data = await request.json()
-        if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
-        stream = data.get('stream', False)
-        chat_request = parse_chat_request(data)
-        if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-            chat_request.model = "llama-3.1-8b"
-        if not chat_request.model or chat_request.model not in shard_mappings:
-            if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
-            chat_request.model = "llama-3.1-8b"
-        shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
-        if not shard:
-            supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
-            return web.json_response({'detail': f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, status=400)
-        request_id = str(uuid.uuid4())
-
-        tokenizer = await resolve_tokenizer(shard.model_id)
-        if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
-
-        prompt = build_prompt(tokenizer, chat_request.messages)
-        callback_id = f"chatgpt-api-wait-response-{request_id}"
-        callback = self.node.on_token.register(callback_id)
-
-        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
-        try:
-            await self.node.process_prompt(shard, prompt, request_id=request_id)
-        except Exception as e:
-            if DEBUG >= 2:
-                import traceback
-                traceback.print_exc()
-            return web.json_response({'detail': f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-
-        try:
-            if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
-
-            if stream:
-                response = web.StreamResponse(
-                    status=200,
-                    reason="OK",
-                    headers={
-                        "Content-Type": "application/json",
-                        "Cache-Control": "no-cache",
-                    }
-                )
-                await response.prepare(request)
-
-                async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-                    prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-                    self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
-                    new_tokens = tokens[prev_last_tokens_len:]
-                    finish_reason = None
-                    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-                    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
-                        new_tokens = new_tokens[:-1]
-                        if is_finished:
-                            finish_reason = "stop"
-                    if is_finished and not finish_reason:
-                        finish_reason = "length"
-
-                    completion = generate_completion(chat_request, tokenizer, prompt, request_id, new_tokens, stream, finish_reason, "chat.completion")
-                    if DEBUG >= 2: print(f"Streaming completion: {completion}")
-                    await response.write(f"data: {json.dumps(completion)}\n\n".encode())
-                def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-                    self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
-
-                    return _request_id == request_id and is_finished
-                _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
-                if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
-                    if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
-                    try:
-                        await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
-                    except asyncio.TimeoutError:
-                        print("WARNING: Stream task timed out. This should not happen.")
-                await response.write_eof()
-                return response
-            else:
-                _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
-
-                finish_reason = "length"
-                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-                if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
-                if tokens[-1] == eos_token_id:
-                    tokens = tokens[:-1]
-                    finish_reason = "stop"
-
-                return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
-        except asyncio.TimeoutError:
-            return web.json_response({'detail': "Response generation timed out"}, status=408)
-        finally:
-            deregistered_callback = self.node.on_token.deregister(callback_id)
-            if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
-
-    async def run(self, host: str = "0.0.0.0", port: int = 8000):
-        runner = web.AppRunner(self.app)
-        await runner.setup()
-        site = web.TCPSite(runner, host, port)
-        await site.start()
-        if DEBUG >= 0:
-            print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
-            print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
+
+        finish_reason = "length"
+        eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+        if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
+        if tokens[-1] == eos_token_id:
+          tokens = tokens[:-1]
+          finish_reason = "stop"
+
+        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
+    except asyncio.TimeoutError:
+      return web.json_response({"detail": "Response generation timed out"}, status=408)
+    finally:
+      deregistered_callback = self.node.on_token.deregister(callback_id)
+      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
+
+  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+    runner = web.AppRunner(self.app)
+    await runner.setup()
+    site = web.TCPSite(runner, host, port)
+    await site.start()
+    if DEBUG >= 0:
+      print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
+      print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")

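The rewritten ChatGPTAPI keeps an OpenAI-compatible surface: parse_chat_request reads model, messages and temperature, run() defaults to 0.0.0.0:8000, and non-streaming responses carry the text under choices[0]["message"]["content"]. A minimal client sketch against a locally running node, using only fields that appear in this diff (stdlib urllib so no extra dependency is assumed):

import json
import urllib.request

payload = {
  "model": "llama-3.1-8b",  # any key from shard_mappings above
  "messages": [{"role": "user", "content": "Hello!"}],
  "temperature": 0.0,
  "stream": False,
}
req = urllib.request.Request(
  "http://localhost:8000/v1/chat/completions",
  data=json.dumps(payload).encode(),
  headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
  completion = json.load(resp)
print(completion["choices"][0]["message"]["content"])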
+ 110 - 102
exo/helpers.py

@@ -18,124 +18,132 @@ exo_text = r"""
     """
     """
 
 def get_system_info():
-    if psutil.MACOS:
-        if platform.machine() == 'arm64':
-            return "Apple Silicon Mac"
-        elif platform.machine() in ['x86_64', 'i386']:
-            return "Intel Mac"
-        else:
-            return "Unknown Mac architecture"
-    elif psutil.LINUX:
-        return "Linux"
-    else:
-        return "Non-Mac, non-Linux system"
+  if psutil.MACOS:
+    if platform.machine() == "arm64":
+      return "Apple Silicon Mac"
+    if platform.machine() in ["x86_64", "i386"]:
+      return "Intel Mac"
+    return "Unknown Mac architecture"
+  if psutil.LINUX:
+    return "Linux"
+  return "Non-Mac, non-Linux system"
+
 
 def get_inference_engine(inference_engine_name):
-    if inference_engine_name == "mlx":
-        from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-        return MLXDynamicShardInferenceEngine()
-    elif inference_engine_name == "tinygrad":
-        from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
-        return TinygradDynamicShardInferenceEngine()
-    else:
-        raise ValueError(f"Inference engine {inference_engine_name} not supported")
-
-def find_available_port(host: str = '', min_port: int = 49152, max_port: int = 65535) -> int:
-    used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.exo_used_ports')
-
-    def read_used_ports():
-        if os.path.exists(used_ports_file):
-            with open(used_ports_file, 'r') as f:
-                return [int(line.strip()) for line in f if line.strip().isdigit()]
-        return []
-
-    def write_used_port(port, used_ports):
-        with open(used_ports_file, 'w') as f:
-            print(used_ports[-19:])
-            for p in used_ports[-19:] + [port]:
-                f.write(f"{p}\n")
-
-    used_ports = read_used_ports()
-    available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
-
-    while available_ports:
-        port = random.choice(list(available_ports))
-        if DEBUG >= 2: print(f"Trying to find available port {port=}")
-        try:
-            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                s.bind((host, port))
-            write_used_port(port, used_ports)
-            return port
-        except socket.error:
-            available_ports.remove(port)
-
-    raise RuntimeError("No available ports in the specified range")
+  if inference_engine_name == "mlx":
+    from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
+
+    return MLXDynamicShardInferenceEngine()
+  elif inference_engine_name == "tinygrad":
+    from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
+
+    return TinygradDynamicShardInferenceEngine()
+  else:
+    raise ValueError(f"Inference engine {inference_engine_name} not supported")
+
+
+def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
+  used_ports_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".exo_used_ports")
+
+  def read_used_ports():
+    if os.path.exists(used_ports_file):
+      with open(used_ports_file, "r") as f:
+        return [int(line.strip()) for line in f if line.strip().isdigit()]
+    return []
+
+  def write_used_port(port, used_ports):
+    with open(used_ports_file, "w") as f:
+      print(used_ports[-19:])
+      for p in used_ports[-19:] + [port]:
+        f.write(f"{p}\n")
+
+  used_ports = read_used_ports()
+  available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
+
+  while available_ports:
+    port = random.choice(list(available_ports))
+    if DEBUG >= 2: print(f"Trying to find available port {port=}")
+    try:
+      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind((host, port))
+      write_used_port(port, used_ports)
+      return port
+    except socket.error:
+      available_ports.remove(port)
+
+  raise RuntimeError("No available ports in the specified range")
+
 
 def print_exo():
-    print(exo_text)
+  print(exo_text)
+
 
 def print_yellow_exo():
-    yellow = "\033[93m"  # ANSI escape code for yellow
-    reset = "\033[0m"    # ANSI escape code to reset color
-    print(f"{yellow}{exo_text}{reset}")
+  yellow = "\033[93m"  # ANSI escape code for yellow
+  reset = "\033[0m"  # ANSI escape code to reset color
+  print(f"{yellow}{exo_text}{reset}")
+
 
 def terminal_link(uri, label=None):
-    if label is None: 
-        label = uri
-    parameters = ''
+  if label is None:
+    label = uri
+  parameters = ""
+
+  # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST
+  escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
+
+  return escape_mask.format(parameters, uri, label)
 
-    # OSC 8 ; params ; URI ST <name> OSC 8 ;; ST 
-    escape_mask = '\033]8;{};{}\033\\{}\033]8;;\033\\'
 
-    return escape_mask.format(parameters, uri, label)
+T = TypeVar("T")
+K = TypeVar("K")
 
-T = TypeVar('T')
-K = TypeVar('K')
 
 class AsyncCallback(Generic[T]):
-    def __init__(self) -> None:
-        self.condition: asyncio.Condition = asyncio.Condition()
-        self.result: Optional[Tuple[T, ...]] = None
-        self.observers: list[Callable[..., None]] = []
-
-    async def wait(self,
-                   check_condition: Callable[..., bool],
-                   timeout: Optional[float] = None) -> Tuple[T, ...]:
-        async with self.condition:
-            await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
-            assert self.result is not None  # for type checking
-            return self.result
-
-    def on_next(self, callback: Callable[..., None]) -> None:
-        self.observers.append(callback)
-
-    def set(self, *args: T) -> None:
-        self.result = args
-        for observer in self.observers:
-            observer(*args)
-        asyncio.create_task(self.notify())
-
-    async def notify(self) -> None:
-        async with self.condition:
-            self.condition.notify_all()
+  def __init__(self) -> None:
+    self.condition: asyncio.Condition = asyncio.Condition()
+    self.result: Optional[Tuple[T, ...]] = None
+    self.observers: list[Callable[..., None]] = []
+
+  async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
+    async with self.condition:
+      await asyncio.wait_for(
+        self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout
+      )
+      assert self.result is not None  # for type checking
+      return self.result
+
+  def on_next(self, callback: Callable[..., None]) -> None:
+    self.observers.append(callback)
+
+  def set(self, *args: T) -> None:
+    self.result = args
+    for observer in self.observers:
+      observer(*args)
+    asyncio.create_task(self.notify())
+
+  async def notify(self) -> None:
+    async with self.condition:
+      self.condition.notify_all()
+
 
 class AsyncCallbackSystem(Generic[K, T]):
-    def __init__(self) -> None:
-        self.callbacks: Dict[K, AsyncCallback[T]] = {}
+  def __init__(self) -> None:
+    self.callbacks: Dict[K, AsyncCallback[T]] = {}
 
-    def register(self, name: K) -> AsyncCallback[T]:
-        if name not in self.callbacks:
-            self.callbacks[name] = AsyncCallback[T]()
-        return self.callbacks[name]
+  def register(self, name: K) -> AsyncCallback[T]:
+    if name not in self.callbacks:
+      self.callbacks[name] = AsyncCallback[T]()
+    return self.callbacks[name]
 
-    def deregister(self, name: K) -> None:
-        if name in self.callbacks:
-            del self.callbacks[name]
+  def deregister(self, name: K) -> None:
+    if name in self.callbacks:
+      del self.callbacks[name]
 
-    def trigger(self, name: K, *args: T) -> None:
-        if name in self.callbacks:
-            self.callbacks[name].set(*args)
+  def trigger(self, name: K, *args: T) -> None:
+    if name in self.callbacks:
+      self.callbacks[name].set(*args)
 
-    def trigger_all(self, *args: T) -> None:
-        for callback in self.callbacks.values():
-            callback.set(*args)
+  def trigger_all(self, *args: T) -> None:
+    for callback in self.callbacks.values():
+      callback.set(*args)

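AsyncCallback/AsyncCallbackSystem is the small pub/sub primitive the API relies on (node.on_token.register(...) and callback.wait(...) in chatgpt_api.py above). A usage sketch built only from the methods defined in this diff; the "on_token" key and the (42, True) payload are illustrative:

import asyncio
from exo.helpers import AsyncCallbackSystem

async def main():
  system = AsyncCallbackSystem()
  callback = system.register("on_token")

  # Trigger the named callback shortly after we start waiting on it.
  asyncio.get_running_loop().call_later(0.1, system.trigger, "on_token", 42, True)

  # wait() returns the triggered args once the predicate over them is satisfied.
  value, is_finished = await callback.wait(lambda _value, done: done, timeout=5)
  print(value, is_finished)
  system.deregister("on_token")

asyncio.run(main())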
+ 44 - 21
exo/inference/debug_inference_engine.py

@@ -1,38 +1,61 @@
-from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-    from exo.inference.tinygrad.inference import Tokenizer
-    from pathlib import Path
-    _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
+  from exo.inference.tinygrad.inference import Tokenizer
+  from pathlib import Path
+
+  _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
 
-    prompt = "In a single word only, what is the last name of the president of the United States? "
-    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
+  prompt = "In a single word only, what is the last name of the president of the United States? "
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+    "A",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    input_data=resp_full,
+    inference_state=inference_state_full,
+  )
 
-    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
-    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
-    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp1,
+    inference_state=inference_state_1,
+  )
+  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
+    input_data=resp2,
+    inference_state=inference_state_2,
+  )
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp3,
+    inference_state=inference_state_3,
+  )
 
 
-    print(f"{resp2=}")
-    print(f"full: {_tokenizer.decode(resp_full)}")
-    print(f"next full: {_tokenizer.decode(next_resp_full)}")
-    print(f"resp2: {_tokenizer.decode(resp2)}")
-    print(f"{resp4=}")
-    print(f"resp4: {_tokenizer.decode(resp4)}")
+  print(f"{resp2=}")
+  print(f"full: {_tokenizer.decode(resp_full)}")
+  print(f"next full: {_tokenizer.decode(next_resp_full)}")
+  print(f"resp2: {_tokenizer.decode(resp2)}")
+  print(f"{resp4=}")
+  print(f"resp4: {_tokenizer.decode(resp4)}")
 
 
-    assert np.array_equal(resp_full, resp2)
-    assert np.array_equal(next_resp_full, resp4)
+  assert np.array_equal(resp_full, resp2)
+  assert np.array_equal(next_resp_full, resp4)
 
 
 
 
-asyncio.run(test_inference_engine(
+asyncio.run(
+  test_inference_engine(
     TinygradDynamicShardInferenceEngine(),
     TinygradDynamicShardInferenceEngine(),
     "llama3-8b-sfr",
-))
+  )
+)
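
The test above splits a 32-layer model into a shard covering layers 0-30 and a shard covering only layer 31. A minimal sketch of that layer arithmetic (not part of the diff), using only the Shard helpers introduced in exo/inference/shard.py later in this commit:

  from exo.inference.shard import Shard

  full = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=31, n_layers=32)
  head = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=30, n_layers=32)
  tail = Shard(model_id="llama3-8b-sfr", start_layer=31, end_layer=31, n_layers=32)

  assert head.is_first_layer() and not head.is_last_layer()
  assert tail.is_last_layer() and not tail.is_first_layer()
  assert head.get_layer_count() + tail.get_layer_count() == full.n_layers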

+ 7 - 6
exo/inference/inference_engine.py

@@ -4,11 +4,12 @@ from typing import Tuple, Optional
 from abc import ABC, abstractmethod
 from .shard import Shard
 
 
+
 class InferenceEngine(ABC):
-    @abstractmethod
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        pass
+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+    pass
 
 
-    @abstractmethod
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        pass
+  @abstractmethod
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+    pass
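
The two abstract methods above are the whole engine contract. A hypothetical sketch (not in the commit) of the minimal subclass they require; EchoEngine and its toy behaviour are illustrative assumptions only:

  from typing import Optional, Tuple
  import numpy as np
  from exo.inference.inference_engine import InferenceEngine
  from exo.inference.shard import Shard

  class EchoEngine(InferenceEngine):
    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
      # Toy behaviour: report the prompt length and mark the request as finished.
      return np.array([len(prompt)]), "", True

    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
      # Toy behaviour: pass the tensor through unchanged.
      return input_data, inference_state or "", True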

+ 9 - 0
exo/inference/mlx/models/base.py

@@ -0,0 +1,9 @@
+from typing import Optional
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import KVCache
+
+
+class IdentityBlock(nn.Module):
+  def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
+    return x
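
IdentityBlock is what the sharded models below substitute for layers the local shard does not own: it keeps absolute layer indices intact while doing no work. An illustrative check (assumes mlx is installed):

  import mlx.core as mx
  from exo.inference.mlx.models.base import IdentityBlock

  x = mx.ones((1, 4, 8))
  y = IdentityBlock()(x)  # mask/cache arguments are accepted but ignored
  print(y.shape)          # (1, 4, 8) -- unchanged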

+ 130 - 0
exo/inference/mlx/models/deepseek_v2.py

@@ -0,0 +1,130 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import KVCache
+from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
+from .base import IdentityBlock
+from ...shard import Shard
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class DeepseekV2Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.num_hidden_layers = config.num_hidden_layers
+    self.vocab_size = config.vocab_size
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(DeepseekV2DecoderLayer(config, i))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+  def __call__(
+    self,
+    x: mx.array,
+    cache: Optional[KVCache] = None,
+  ) -> mx.array:
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(x)
+    else:
+      h = x
+
+    mask = None
+    T = h.shape[1]
+    if T > 1:
+      mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, config: ModelArgs):
+    super().__init__()
+    self.args = config
+    self.model_type = config.model_type
+    self.model = DeepseekV2Model(config)
+    if self.args.shard.is_last_layer():
+      self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache: Optional[KVCache] = None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      return self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    for l in range(self.args.num_hidden_layers):
+      prefix = f"model.layers.{l}"
+      for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
+        for k in ["weight", "scales", "biases"]:
+          if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
+            to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
+            shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return (
+      self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
+      self.args.v_head_dim,
+    )
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
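
The loop in sanitize() above merges the per-expert projection weights into a single stacked tensor keyed under switch_mlp, which appears to be the layout mlx_lm's DeepSeek-V2 MoE layer expects. A standalone sketch of just the stacking step (the expert count and shapes here are made-up demo values):

  import mlx.core as mx

  n_routed_experts = 4  # demo value; the real count comes from the model config
  weights = {f"model.layers.0.mlp.experts.{e}.gate_proj.weight": mx.zeros((8, 8)) for e in range(n_routed_experts)}
  stacked = mx.stack([weights.pop(f"model.layers.0.mlp.experts.{e}.gate_proj.weight") for e in range(n_routed_experts)])
  print(stacked.shape)  # (4, 8, 8): one leading axis per expert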

+ 124 - 0
exo/inference/mlx/models/llama.py

@@ -0,0 +1,124 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_additive_causal_mask
+from mlx_lm.models.llama import TransformerBlock, ModelArgs
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    super().__post_init__()  # Ensure parent initializations are respected
+
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class LlamaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if self.args.shard.is_first_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+
+    if self.args.shard.is_last_layer():
+      self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+    else:
+      h = inputs
+
+    mask = None
+    if h.shape[1] > 1:
+      mask = create_additive_causal_mask(h.shape[1], cache[0].offset if cache is not None else 0)
+      mask = mask.astype(h.dtype)
+
+    if cache is None:
+      cache = [None] * len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+
+    self.model = LlamaModel(args)
+    if self.args.shard.is_last_layer():
+      if not args.tie_word_embeddings:
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      if self.args.tie_word_embeddings:
+        out = self.model.embed_tokens.as_linear(out)
+      else:
+        out = self.lm_head(out)
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.hidden_size // self.args.num_attention_heads
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
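
sanitize() keeps only the checkpoint keys the current shard owns: its own layers, plus the embeddings on the first shard and the norm/head on the last. The filtering rule, paraphrased as a standalone helper (not the commit's code) so it can be run without building a model:

  from exo.inference.shard import Shard

  def keys_kept_by_shard(keys, shard: Shard):
    kept = []
    for key in keys:
      if key.startswith("model.layers."):
        if shard.start_layer <= int(key.split(".")[2]) <= shard.end_layer:
          kept.append(key)
      elif shard.is_first_layer() and key.startswith("model.embed_tokens"):
        kept.append(key)
      elif shard.is_last_layer() and (key.startswith("model.norm") or key.startswith("lm_head")):
        kept.append(key)
    return kept

  keys = ["model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight",
          "model.layers.31.self_attn.q_proj.weight", "model.norm.weight", "lm_head.weight"]
  first_half = Shard(model_id="llama", start_layer=0, end_layer=15, n_layers=32)
  print(keys_kept_by_shard(keys, first_half))  # embeddings + layer 0 only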

+ 0 - 351
exo/inference/mlx/models/sharded_llama.py

@@ -1,351 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Dict, Optional, Tuple, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-
-from exo.inference.shard import Shard
-from mlx_lm.models.base import BaseModelArgs, KVCache, create_additive_causal_mask
-
-@dataclass
-class NormalModelArgs(BaseModelArgs):
-    model_type: str
-    hidden_size: int
-    num_hidden_layers: int
-    intermediate_size: int
-    num_attention_heads: int
-    rms_norm_eps: float
-    vocab_size: int
-    head_dim: Optional[int] = None
-    max_position_embeddings: Optional[int] = None
-    num_key_value_heads: Optional[int] = None
-    attention_bias: bool = False
-    mlp_bias: bool = False
-    rope_theta: float = 10000
-    rope_traditional: bool = False
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
-    tie_word_embeddings: bool = True
-
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.rope_scaling:
-            if not "factor" in self.rope_scaling:
-                raise ValueError(f"rope_scaling must contain 'factor'")
-            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
-                "rope_type"
-            )
-            if rope_type is None:
-                raise ValueError(
-                    f"rope_scaling must contain either 'type' or 'rope_type'"
-                )
-            if rope_type not in ["linear", "dynamic", "llama3"]:
-                raise ValueError(
-                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
-                )
-
-@dataclass
-class ModelArgs(NormalModelArgs):
-    shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
-
-    def __post_init__(self):
-        super().__post_init__()  # Ensure parent initializations are respected
-
-        if isinstance(self.shard, Shard):
-            return
-        if not isinstance(self.shard, dict):
-            raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
-
-        self.shard = Shard(**self.shard)
-
-class DynamicNTKScalingRoPE(nn.Module):
-    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
-
-    def __init__(
-        self,
-        dims: int,
-        max_position_embeddings: int = 2048,
-        traditional: bool = False,
-        base: float = 10000,
-        scale: float = 1.0,
-        rope_type: str = "default",
-        rope_scaling: dict = None,
-    ):
-        super().__init__()
-        self.dims = dims
-        self.max_position_embeddings = max_position_embeddings
-        self.traditional = traditional
-        self.original_base = base
-        self.scale = scale
-        self.rope_type = rope_type
-        self.rope_scaling = rope_scaling
-        self.base = self.compute_base_freq()
-
-    def compute_base_freq(self):
-        if self.rope_type == "llama3":
-            return self.compute_llama3_base_freq()
-        return self.original_base
-
-    # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
-    def compute_llama3_base_freq(self):
-        factor = self.rope_scaling["factor"]
-        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
-        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
-        old_context_len = self.rope_scaling.get(
-            "original_max_position_embeddings",
-            8192,
-        )
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-
-        freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
-        wavelens = 2 * mx.pi * freqs
-        new_base_freqs = []
-
-        smooths = (wavelens - high_freq_wavelen) / (
-            low_freq_wavelen - high_freq_wavelen
-        )
-        new_base_freqs = freqs * (1 - smooths) * factor + smooths
-        new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
-        new_base_freqs = mx.where(
-            wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
-        )
-        return new_base_freqs.mean().item()
-
-    def extra_repr(self):
-        return (
-            f"{self.dims}, traditional={self.traditional}, "
-            f"max_position_embeddings={self.max_position_embeddings}, "
-            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
-        )
-
-    def __call__(self, x, offset: int = 0):
-        seq_len = x.shape[1] + offset
-        base = self.base
-        if self.max_position_embeddings and seq_len > self.max_position_embeddings:
-            base *= (
-                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
-            ) ** (self.dims / (self.dims - 2))
-
-        return mx.fast.rope(
-            x,
-            self.dims,
-            traditional=self.traditional,
-            base=base,
-            scale=self.scale,
-            offset=offset,
-        )
-
-
-def initialize_rope(args: ModelArgs):
-    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
-
-    rope_scaling = args.rope_scaling
-    rope_type = "default"
-    rope_scale = 1.0
-
-    if rope_scaling is not None:
-        rope_type = (
-            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
-        )
-        if rope_type == "linear":
-            rope_scale = 1 / rope_scaling["factor"]
-        elif rope_type == "llama3":
-            rope_scale = 1.0  # The scaling is handled internally for llama3
-
-    return DynamicNTKScalingRoPE(
-        dims=head_dim,
-        max_position_embeddings=args.max_position_embeddings,
-        traditional=args.rope_traditional,
-        base=args.rope_theta,
-        scale=rope_scale,
-        rope_type=rope_type,
-        rope_scaling=rope_scaling,
-    )
-
-
-class Attention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        dim = args.hidden_size
-        self.n_heads = n_heads = args.num_attention_heads
-        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-
-        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
-
-        self.scale = head_dim**-0.5
-        if hasattr(args, "attention_bias"):
-            attention_bias = args.attention_bias
-        else:
-            attention_bias = False
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
-
-        self.rope = initialize_rope(args)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output)
-
-
-class MLP(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        dim = args.hidden_size
-        hidden_dim = args.intermediate_size
-        if hasattr(args, "mlp_bias"):
-            mlp_bias = args.mlp_bias
-        else:
-            mlp_bias = False
-
-        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
-        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
-
-    def __call__(self, x) -> mx.array:
-        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.num_attention_heads = args.num_attention_heads
-        self.hidden_size = args.hidden_size
-        self.self_attn = Attention(args)
-        self.mlp = MLP(args)
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(
-            args.hidden_size, eps=args.rms_norm_eps
-        )
-        self.args = args
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[KVCache] = None,
-    ) -> mx.array:
-        r = self.self_attn(self.input_layernorm(x), mask, cache)
-        h = x + r
-        r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        return out
-
-
-class LlamaModel(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.vocab_size = args.vocab_size
-        self.num_hidden_layers = args.num_hidden_layers
-        assert self.vocab_size > 0
-        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [
-            TransformerBlock(args=args) for _ in range(args.shard.end_layer - args.shard.start_layer + 1)
-        ]
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        if self.args.shard.is_first_layer():
-            h = self.embed_tokens(inputs)
-        else:
-            h = inputs
-
-        mask = None
-        if h.shape[1] > 1:
-            mask = create_additive_causal_mask(
-                h.shape[1], cache[0].offset if cache is not None else 0
-            )
-            mask = mask.astype(h.dtype)
-
-        if cache is None:
-            cache = [None] * len(self.layers)
-
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, cache=c)
-
-        if self.args.shard.is_last_layer():
-            return self.norm(h)
-        else:
-            return h
-
-
-class Model(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.args = args
-        self.model_type = args.model_type
-        self.model = LlamaModel(args)
-        if not args.tie_word_embeddings:
-            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        out = self.model(inputs, cache)
-
-        if self.args.shard.is_last_layer():
-            if self.args.tie_word_embeddings:
-                out = self.model.embed_tokens.as_linear(out)
-            else:
-                out = self.lm_head(out)
-
-        return out
-
-    def sanitize(self, weights):
-        # Remove unused precomputed rotary freqs
-        return {
-            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
-        }
-
-    @property
-    def layers(self):
-        return self.model.layers
-
-    @property
-    def head_dim(self):
-        return (
-            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
-        )
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads

+ 17 - 16
exo/inference/mlx/sharded_inference_engine.py

@@ -6,24 +6,25 @@ from .sharded_utils import load_shard
 from ..shard import Shard
 from typing import Optional
 
 
+
 class MLXDynamicShardInferenceEngine(InferenceEngine):
-    def __init__(self):
-        self.shard = None
+  def __init__(self):
+    self.shard = None
 
 
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    await self.ensure_shard(shard)
+    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
+    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
 
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
-        return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    await self.ensure_shard(shard)
+    output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
+    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
 
-    async def ensure_shard(self, shard: Shard):
-        if self.shard == shard:
-            return
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
 
 
-        model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
-        self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
-        self.shard = shard
+    model_shard, self.tokenizer = await load_shard(shard.model_id, shard)
+    self.stateful_sharded_model = StatefulShardedModel(shard, model_shard)
+    self.shard = shard

+ 52 - 55
exo/inference/mlx/sharded_model.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generator, Optional, Tuple
+from typing import Dict, Generator, Optional, Tuple
 
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -7,67 +7,64 @@ from mlx_lm.sample_utils import top_p_sampling
 
 
 from ..shard import Shard
 
 
+
 class StatefulShardedModel:
-    def __init__(self, shard: Shard, model: nn.Module):
-        self.shard = shard
-        self.model = model
-        self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
+  def __init__(self, shard: Shard, model: nn.Module):
+    self.shard = shard
+    self.model = model
+    self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
 
 
-    def step(
-        self,
-        request_id: str,
-        x,
-        pixel_values=None,
-        temp: float = 0.0,
-        top_p: float = 1.0,
-        logit_bias: Optional[Dict[int, float]] = None,
-    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-        def sample(logits: mx.array) -> Tuple[mx.array, float]:
-            if logit_bias:
-                indices = mx.array(list(logit_bias.keys()))
-                values = mx.array(list(logit_bias.values()))
-                logits[:, indices] += values
+  def step(
+    self,
+    request_id: str,
+    x,
+    pixel_values=None,
+    temp: float = 0.0,
+    top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
+  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    def sample(logits: mx.array) -> Tuple[mx.array, float]:
+      if logit_bias:
+        indices = mx.array(list(logit_bias.keys()))
+        values = mx.array(list(logit_bias.values()))
+        logits[:, indices] += values
 
 
-            if temp == 0:
-                token = mx.argmax(logits, axis=-1)
-            else:
-                if top_p > 0 and top_p < 1.0:
-                    token = top_p_sampling(logits, top_p, temp)
-                else:
-                    token = mx.random.categorical(logits * (1 / temp))
+      if temp == 0:
+        token = mx.argmax(logits, axis=-1)
+      else:
+        if top_p > 0 and top_p < 1.0:
+          token = top_p_sampling(logits, top_p, temp)
+        else:
+          token = mx.random.categorical(logits * (1 / temp))
 
 
-            return token
+      return token
 
 
-        y = x
+    y = x
 
 
-        if request_id not in self.request_cache:
-            self.init_cache(request_id)
+    if request_id not in self.request_cache:
+      self.init_cache(request_id)
 
 
-        if pixel_values is None:
-            output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
-        else:
-            output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
+    if pixel_values is None:
+      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
+    else:
+      output = self.model(y, pixel_values=pixel_values, cache=self.request_cache[request_id])
 
 
-        if self.shard.is_last_layer():
-            logits = output[:, -1, :]
-            y = sample(logits)
-            return y
-        else:
-            return output
+    if self.shard.is_last_layer():
+      logits = output[:, -1, :]
+      y = sample(logits)
+      return y
+    else:
+      return output
 
 
-    def __call__(
-            self,
-            x,
-            temp: float = 0.0,
-            top_p: float = 1.0,
-        logit_bias: Optional[Dict[int, float]] = None,
-    ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-        return self.step(x, temp, top_p, logit_bias)
+  def __call__(
+    self,
+    x,
+    temp: float = 0.0,
+    top_p: float = 1.0,
+    logit_bias: Optional[Dict[int, float]] = None,
+  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    return self.step(x, temp, top_p, logit_bias)
 
 
-    def init_cache(self, request_id: str):
-        kv_heads = (
-            [self.model.n_kv_heads] * len(self.model.layers)
-            if isinstance(self.model.n_kv_heads, int)
-            else self.model.n_kv_heads
-        )
-        self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
+  def init_cache(self, request_id: str):
+    kv_heads = [self.model.n_kv_heads] * len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads
+    self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
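
step() returns intermediate activations on a non-final shard and a sampled token on the final shard, keeping one KVCache list per request_id. A hedged usage sketch (not in the commit); the two StatefulShardedModel halves are assumed to have been built and loaded elsewhere:

  import mlx.core as mx
  from exo.inference.mlx.sharded_model import StatefulShardedModel

  def run_split(first_half: StatefulShardedModel, second_half: StatefulShardedModel, token_ids, request_id: str = "demo"):
    """Push one prompt through two shards by hand and return the sampled next token."""
    hidden = first_half.step(request_id, mx.array(token_ids))  # first shard: embeddings + its layers
    return second_half.step(request_id, hidden)                # last shard: norm + head + sampling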

+ 182 - 198
exo/inference/mlx/sharded_utils.py

@@ -20,225 +20,209 @@ from mlx_lm.tuner.utils import apply_lora_layers
 
 
 from ..shard import Shard
 
 
+
 class ModelNotFoundError(Exception):
-    def __init__(self, message):
-        self.message = message
-        super().__init__(self.message)
+  def __init__(self, message):
+    self.message = message
+    super().__init__(self.message)
+
 
 
 MODEL_REMAPPING = {
-    "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
-    "sharded_phi-msft": "sharded_phixtral",
-    "sharded_llava": "sharded_llava"
+  "mistral": "llama",  # mistral is compatible with llama
+  "phi-msft": "phixtral",
 }
 
 
+
 def _get_classes(config: dict):
-    """
-    Retrieve the model and model args classes based on the configuration.
+  """
+  Retrieve the model and model args classes based on the configuration.
 
 
-    Args:
-        config (dict): The model configuration.
+  Args:
+   config (dict): The model configuration.
 
 
-    Returns:
-        A tuple containing the Model class and the ModelArgs class.
-    """
-    model_type = config["model_type"]
-    model_type = MODEL_REMAPPING.get(model_type, model_type)
-    try:
-        arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
-    except ImportError:
-        msg = f"Model type {model_type} not supported."
-        logging.error(msg)
-        raise ValueError(msg)
+  Returns:
+   A tuple containing the Model class and the ModelArgs class.
+  """
+  model_type = config["model_type"]
+  model_type = MODEL_REMAPPING.get(model_type, model_type)
+  try:
+    arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
+  except ImportError:
+    msg = f"Model type {model_type} not supported."
+    logging.error(msg)
+    raise ValueError(msg)
+
+  return arch.Model, arch.ModelArgs
 
 
-    return arch.Model, arch.ModelArgs
 
 
 def load_config(model_path: Path) -> dict:
-    try:
-        with open(model_path / "config.json", "r") as f:
-            config = json.load(f)
-    except FileNotFoundError:
-        logging.error(f"Config file not found in {model_path}")
-        raise
-    return config
+  try:
+    with open(model_path / "config.json", "r") as f:
+      config = json.load(f)
+  except FileNotFoundError:
+    logging.error(f"Config file not found in {model_path}")
+    raise
+  return config
+
 
 
 def load_model_shard(
-    model_path: Path,
-    shard: Shard,
-    lazy: bool = False,
-    model_config: dict = {},
+  model_path: Path,
+  shard: Shard,
+  lazy: bool = False,
+  model_config: dict = {},
 ) -> nn.Module:
-    """
-    Load and initialize the model from a given path.
-
-    Args:
-        model_path (Path): The path to load the model from.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-        model_config(dict, optional): Configuration parameters for the model.
-            Defaults to an empty dictionary.
-
-    Returns:
-        nn.Module: The loaded and initialized model.
-
-    Raises:
-        FileNotFoundError: If the weight files (.safetensors) are not found.
-        ValueError: If the model class or args class are not found or cannot be instantiated.
-    """
-
-    config = load_config(model_path)
-    config.update(model_config)
-
-    # TODO hack
-    config["model_type"] = f"sharded_{config['model_type']}"
-    config["shard"] = {
-        "model_id": model_path.name,
-        "start_layer": shard.start_layer,
-        "end_layer": shard.end_layer,
-        "n_layers": shard.n_layers
-    }
-
-    weight_files = glob.glob(str(model_path / "model*.safetensors"))
-
-    if not weight_files:
-        # Try weight for back-compat
-        weight_files = glob.glob(str(model_path / "weight*.safetensors"))
-
-    if not weight_files:
-        logging.error(f"No safetensors found in {model_path}")
-        raise FileNotFoundError(f"No safetensors found in {model_path}")
-
-    weights = {}
-    all_weights_keys = set()
-    for wf in weight_files:
-        weights_dict = mx.load(wf)
-        all_weights_keys.update(weights_dict.keys())
-        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("language_model.model.layers.") or shard.start_layer <= int(k.split('.')[3]) <= shard.end_layer})
-        weights.update({k: v for k, v in weights_dict.items() if not k.startswith("model.layers.") or shard.start_layer <= int(k.split('.')[2]) <= shard.end_layer})
-
-    model_class, model_args_class = _get_classes(config=config)
-
-    model_args = model_args_class.from_dict(config)
-    model = model_class(model_args)
-
-    if hasattr(model, "sanitize"):
-        weights = model.sanitize(weights)
-
-    if (quantization := config.get("quantization", None)) is not None:
-        nn.quantize(
-            model,
-            **quantization,
-            class_predicate=None,
-        )
+  """
+  Load and initialize the model from a given path.
 
 
-    filtered_weights = {}
-    for k, v in weights.items():
-        if k.startswith("model.layers."):
-            layer_num = int(k.split('.')[2])
-            if shard.start_layer <= layer_num <= shard.end_layer:
-                new_key = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
-                filtered_weights[new_key] = v
-        elif k.startswith("language_model.model.layers."):
-            layer_num = int(k.split('.')[3])
-            if shard.start_layer <= layer_num <= shard.end_layer:
-                new_key = f"language_model.model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[4:])
-                filtered_weights[new_key] = v
-        else:
-            filtered_weights[k] = v
-    weights = filtered_weights
-
-    model.load_weights(list(weights.items()), strict=False)
-
-    if not lazy:
-        mx.eval(model.parameters())
+  Args:
+   model_path (Path): The path to load the model from.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
+   model_config(dict, optional): Configuration parameters for the model.
+    Defaults to an empty dictionary.
+
+  Returns:
+   nn.Module: The loaded and initialized model.
+
+  Raises:
+   FileNotFoundError: If the weight files (.safetensors) are not found.
+   ValueError: If the model class or args class are not found or cannot be instantiated.
+  """
+  config = load_config(model_path)
+  config.update(model_config)
+
+  # TODO hack
+  config["shard"] = {
+    "model_id": model_path.name,
+    "start_layer": shard.start_layer,
+    "end_layer": shard.end_layer,
+    "n_layers": shard.n_layers,
+  }
+
+  weight_files = glob.glob(str(model_path / "model*.safetensors"))
+
+  if not weight_files:
+    # Try weight for back-compat
+    weight_files = glob.glob(str(model_path / "weight*.safetensors"))
+
+  if not weight_files:
+    logging.error(f"No safetensors found in {model_path}")
+    raise FileNotFoundError(f"No safetensors found in {model_path}")
+
+  weights = {}
+  for wf in weight_files:
+    weights.update(mx.load(wf))
+
+  model_class, model_args_class = _get_classes(config=config)
+
+  model_args = model_args_class.from_dict(config)
+  model = model_class(model_args)
+
+  if hasattr(model, "sanitize"):
+    weights = model.sanitize(weights)
+
+  if (quantization := config.get("quantization", None)) is not None:
+    nn.quantize(
+      model,
+      **quantization,
+      class_predicate=None,
+    )
+
+  model.load_weights(list(weights.items()))
+
+  if not lazy:
+    mx.eval(model.parameters())
+
+  model.eval()
+  return model
 
 
-    model.eval()
-    return model
 
 
 async def snapshot_download_async(*args, **kwargs):
-    func = partial(snapshot_download, *args, **kwargs)
-    return await asyncio.get_event_loop().run_in_executor(None, func)
+  func = partial(snapshot_download, *args, **kwargs)
+  return await asyncio.get_event_loop().run_in_executor(None, func)
+
 
 
 async def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
-    """
-    Ensures the model is available locally. If the path does not exist locally,
-    it is downloaded from the Hugging Face Hub.
-
-    Args:
-        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
-        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
-
-    Returns:
-        Path: The path to the model.
-    """
-    model_path = Path(path_or_hf_repo)
-    if not model_path.exists():
-        try:
-            model_path = Path(
-                await snapshot_download_async(
-                    repo_id=path_or_hf_repo,
-                    revision=revision,
-                    allow_patterns=[
-                        "*.json",
-                        "*.safetensors",
-                        "*.py",
-                        "tokenizer.model",
-                        "*.tiktoken",
-                        "*.txt",
-                    ],
-                )
-            )
-        except RepositoryNotFoundError:
-            raise ModelNotFoundError(
-                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
-                "Please make sure you specified the local path or Hugging Face"
-                " repo id correctly.\nIf you are trying to access a private or"
-                " gated Hugging Face repo, make sure you are authenticated:\n"
-                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
-            ) from None
-    return model_path
+  """
+  Ensures the model is available locally. If the path does not exist locally,
+  it is downloaded from the Hugging Face Hub.
+
+  Args:
+   path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
+   revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
+
+  Returns:
+   Path: The path to the model.
+  """
+  model_path = Path(path_or_hf_repo)
+  if not model_path.exists():
+    try:
+      model_path = Path(
+        await snapshot_download_async(
+          repo_id=path_or_hf_repo,
+          revision=revision,
+          allow_patterns=[
+            "*.json",
+            "*.safetensors",
+            "*.py",
+            "tokenizer.model",
+            "*.tiktoken",
+            "*.txt",
+          ],
+        )
+      )
+    except RepositoryNotFoundError:
+      raise ModelNotFoundError(
+        f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
+        "Please make sure you specified the local path or Hugging Face"
+        " repo id correctly.\nIf you are trying to access a private or"
+        " gated Hugging Face repo, make sure you are authenticated:\n"
+        "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
+      ) from None
+  return model_path
 
 
 
 
 async def load_shard(
-    path_or_hf_repo: str,
-    shard: Shard,
-    tokenizer_config={},
-    model_config={},
-    adapter_path: Optional[str] = None,
-    lazy: bool = False,
+  path_or_hf_repo: str,
+  shard: Shard,
+  tokenizer_config={},
+  model_config={},
+  adapter_path: Optional[str] = None,
+  lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
-    """
-    Load the model and tokenizer from a given path or a huggingface repository.
-
-    Args:
-        path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
-        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
-            Defaults to an empty dictionary.
-        model_config(dict, optional): Configuration parameters specifically for the model.
-            Defaults to an empty dictionary.
-        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
-            to the model. Default: ``None``.
-        lazy (bool): If False eval the model parameters to make sure they are
-            loaded in memory before returning, otherwise they will be loaded
-            when needed. Default: ``False``
-    Returns:
-        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
-
-    Raises:
-        FileNotFoundError: If config file or safetensors are not found.
-        ValueError: If model class or args class are not found.
-    """
-    model_path = await get_model_path(path_or_hf_repo)
-
-    model = load_model_shard(model_path, shard, lazy, model_config)
-    if adapter_path is not None:
-        model = apply_lora_layers(model, adapter_path)
-        model.eval()
-
-    # TODO: figure out a generic solution
-    if model.model_type == "llava":
-        processor = AutoProcessor.from_pretrained(model_path)
-        return model, processor
-    else:
-        tokenizer = load_tokenizer(model_path, tokenizer_config)
-        return model, tokenizer
+  """
+  Load the model and tokenizer from a given path or a huggingface repository.
+
+  Args:
+   path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
+   tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
+    Defaults to an empty dictionary.
+   model_config(dict, optional): Configuration parameters specifically for the model.
+    Defaults to an empty dictionary.
+   adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
+    to the model. Default: ``None``.
+   lazy (bool): If False eval the model parameters to make sure they are
+    loaded in memory before returning, otherwise they will be loaded
+    when needed. Default: ``False``
+  Returns:
+   Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
+
+  Raises:
+   FileNotFoundError: If config file or safetensors are not found.
+   ValueError: If model class or args class are not found.
+  """
+  model_path = await get_model_path(path_or_hf_repo)
+
+  model = load_model_shard(model_path, shard, lazy, model_config)
+  if adapter_path is not None:
+    model = apply_lora_layers(model, adapter_path)
+    model.eval()
+
+  # TODO: figure out a generic solution
+  if model.model_type == "llava":
+    processor = AutoProcessor.from_pretrained(model_path)
+    return model, processor
+  else:
+    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    return model, tokenizer
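
A hedged end-to-end sketch of load_shard (not part of the diff); the repo id matches the one used in test_inference_engine.py below, and the first run will download the model weights:

  import asyncio
  from exo.inference.shard import Shard
  from exo.inference.mlx.sharded_utils import load_shard

  async def main():
    shard = Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=15, n_layers=32)
    model, tokenizer = await load_shard(shard.model_id, shard)
    print(type(model).__name__, len(model.layers))  # e.g. Model, 32 -- layers outside the shard are IdentityBlocks

  if __name__ == "__main__":
    asyncio.run(main())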

+ 5 - 5
exo/inference/mlx/test_sharded_llama.py

@@ -23,17 +23,17 @@ max_tokens = 50
 resp = prompt_tokens
 full_generated_tokens = []
 for _ in range(max_tokens):
-    resp = full.step(resp)
-    full_generated_tokens.append(resp.item())
+  resp = full.step(resp)
+  full_generated_tokens.append(resp.item())
 
 
 print("full response: ", full_tokenizer.decode(full_generated_tokens))
 print("full response: ", full_tokenizer.decode(full_generated_tokens))
 
 
 sharded_generated_tokens = []
 sharded_resp = prompt_tokens
 for _ in range(max_tokens):
-    resp1 = m1.step(sharded_resp)
-    sharded_resp = m2.step(resp1)
-    sharded_generated_tokens.append(sharded_resp.item())
+  resp1 = m1.step(sharded_resp)
+  sharded_resp = m2.step(resp1)
+  sharded_generated_tokens.append(sharded_resp.item())
 
 
 print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
 print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
 
 

+ 29 - 28
exo/inference/mlx/test_sharded_model.py

@@ -1,36 +1,37 @@
 from exo.inference.shard import Shard
-from exo.inference.mlx.sharded_model import StatefulShardedModel
 import mlx.core as mx
 import mlx.nn as nn
 from typing import Optional
 import numpy as np
 
 
+
 class DummyModel(nn.Module):
-    def __init__(self, shard: Optional[Shard] = None):
-        self.shard = shard
-        self.layers = [
-            nn.Linear(8, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 128),
-            nn.Linear(128, 8),
-        ]
-
-        self.n_kv_heads = 4
-        self.head_dim = 4
-
-    def __call__(self, x, cache=None):
-        if self.shard:
-            for layer in self.layers[self.shard.start_layer:self.shard.end_layer+1]:
-                x = layer(x)
-            if self.shard.is_last_layer():
-                x =  x.reshape((1, 2, 4))
-        else:
-            for layer in self.layers:
-                x = layer(x)
-            x = x.reshape((1, 2, 4))
-
-        return x
+  def __init__(self, shard: Optional[Shard] = None):
+    self.shard = shard
+    self.layers = [
+      nn.Linear(8, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 128),
+      nn.Linear(128, 8),
+    ]
+
+    self.n_kv_heads = 4
+    self.head_dim = 4
+
+  def __call__(self, x, cache=None):
+    if self.shard:
+      for layer in self.layers[self.shard.start_layer : self.shard.end_layer + 1]:
+        x = layer(x)
+      if self.shard.is_last_layer():
+        x = x.reshape((1, 2, 4))
+    else:
+      for layer in self.layers:
+        x = layer(x)
+      x = x.reshape((1, 2, 4))
+
+    return x
+
 
 
 model = DummyModel()
 model.save_weights("./test_weights.npz")
@@ -44,8 +45,8 @@ model.load_weights("./test_weights.npz")
 sharded_model1.load_weights("./test_weights.npz")
 sharded_model2.load_weights("./test_weights.npz")
 
 
-fullresp = model(mx.array([1,2,3,4,5,6,7,8]))
-resp1 = sharded_model1(mx.array([1,2,3,4,5,6,7,8]))
+fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
+resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
 resp2 = sharded_model2(resp1)
 
 
 assert np.all(np.array(fullresp) == np.array(resp2))

+ 18 - 17
exo/inference/shard.py

@@ -1,25 +1,26 @@
 from dataclasses import dataclass
 
 
+
 @dataclass
 class Shard:
-    model_id: str
-    start_layer: int
-    end_layer: int
-    n_layers: int
+  model_id: str
+  start_layer: int
+  end_layer: int
+  n_layers: int
 
 
-    def is_first_layer(self) -> bool:
-        return self.start_layer == 0
+  def is_first_layer(self) -> bool:
+    return self.start_layer == 0
 
 
-    def is_last_layer(self) -> bool:
-        return self.end_layer == self.n_layers - 1
+  def is_last_layer(self) -> bool:
+    return self.end_layer == self.n_layers - 1
 
 
-    def get_layer_count(self) -> int:
-        return self.end_layer - self.start_layer + 1
+  def get_layer_count(self) -> int:
+    return self.end_layer - self.start_layer + 1
 
 
-    def to_dict(self) -> dict:
-        return {
-            "model_id": self.model_id,
-            "start_layer": self.start_layer,
-            "end_layer": self.end_layer,
-            "n_layers": self.n_layers
-        }
+  def to_dict(self) -> dict:
+    return {
+      "model_id": self.model_id,
+      "start_layer": self.start_layer,
+      "end_layer": self.end_layer,
+      "n_layers": self.n_layers,
+    }
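
to_dict() is what lets a Shard travel over the wire; the ModelArgs.__post_init__ methods earlier in this commit accept the resulting dict and rebuild the dataclass. A minimal round-trip check (not part of the diff):

  from exo.inference.shard import Shard

  shard = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=15, n_layers=32)
  payload = shard.to_dict()     # plain dict, safe to serialize
  rebuilt = Shard(**payload)
  assert rebuilt == shard and rebuilt.get_layer_count() == 16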

+ 35 - 12
exo/inference/test_inference_engine.py

@@ -1,29 +1,52 @@
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
-from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
 import asyncio
 import numpy as np
 
 
+
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-    prompt = "In a single word only, what is the last name of the current president of the USA?"
-    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
+  prompt = "In a single word only, what is the last name of the current president of the USA?"
+  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+    "A",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
+    input_data=resp_full,
+    inference_state=inference_state_full,
+  )
+
+  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp1,
+    inference_state=inference_state_1,
+  )
+  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
+    input_data=resp2,
+    inference_state=inference_state_2,
+  )
+  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+    "B",
+    shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
+    input_data=resp3,
+    inference_state=inference_state_3,
+  )
 
 
-    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
-    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
-    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
+  assert np.array_equal(resp_full, resp2)
+  assert np.array_equal(next_resp_full, resp4)
 
 
-    assert np.array_equal(resp_full, resp2)
-    assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(
+asyncio.run(
+  test_inference_engine(
     MLXDynamicShardInferenceEngine(),
     MLXDynamicShardInferenceEngine(),
     "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
-))
+  )
+)
 
 
 # TODO: Need more memory or a smaller model
 # TODO: Need more memory or a smaller model
 # asyncio.run(test_inference_engine(
+ 180 - 93
exo/inference/tinygrad/inference.py

@@ -2,12 +2,12 @@ import asyncio
 from functools import partial
 from pathlib import Path
 from typing import List, Optional, Union
-import json, argparse, random, time
+import json
 import tiktoken
 from tiktoken.load import load_tiktoken_bpe
 from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
-from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
-from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
+from tinygrad.nn.state import safe_load, torch_load, load_state_dict
+from tinygrad import Tensor, nn, Context, GlobalCounters
 from tinygrad.helpers import DEBUG, tqdm, _cache_dir, fetch
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import InferenceEngine
@@ -16,17 +16,37 @@ import os
 
 
 MODEL_PARAMS = {
   "8B": {
-    "args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336},
-    "files": 1
+    "args": {
+      "dim": 4096,
+      "n_heads": 32,
+      "n_kv_heads": 8,
+      "n_layers": 32,
+      "norm_eps": 1e-5,
+      "rope_theta": 500000,
+      "vocab_size": 128256,
+      "hidden_dim": 14336,
+    },
+    "files": 1,
   },
   "70B": {
-    "args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256,  "hidden_dim": 28672},
-    "files": 8
-  }
+    "args": {
+      "dim": 8192,
+      "n_heads": 64,
+      "n_kv_heads": 8,
+      "n_layers": 80,
+      "norm_eps": 1e-5,
+      "rope_theta": 500000,
+      "vocab_size": 128256,
+      "hidden_dim": 28672,
+    },
+    "files": 8,
+  },
 }
 
 
+
 class Tokenizer:
   pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+
   def __init__(self, model_path: str):
     mergeable_ranks = load_tiktoken_bpe(model_path)
     self.num_base_tokens = len(mergeable_ranks)
@@ -41,29 +61,36 @@ class Tokenizer:
       "<|end_header_id|>",
       "<|end_header_id|>",
       "<|reserved_special_token_4|>",
       "<|reserved_special_token_4|>",
       "<|eot_id|>",
       "<|eot_id|>",
-    ] + [
-      f"<|reserved_special_token_{i}|>"
-      for i in range(5, 256 - 5)
-    ]
+    ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
     self.special_tokens = {token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}
 
 
     self.model = tiktoken.Encoding(name=model_path, pat_str=self.pat_str, mergeable_ranks=mergeable_ranks, special_tokens=self.special_tokens)
 
 
   @property
-  def bos_id(self): return self.special_tokens["<|begin_of_text|>"]
+  def bos_id(self):
+    return self.special_tokens["<|begin_of_text|>"]
+
   @property
-  def stop_tokens(self): return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
+  def stop_tokens(self):
+    return {self.special_tokens["<|end_of_text|>"], self.special_tokens["<|eot_id|>"]}
 
 
   def decode(self, toks):
-     return self.model.decode([t for t in toks if t < self.num_base_tokens])
+    return self.model.decode([t for t in toks if t < self.num_base_tokens])
+
   def encode(self, text, allow_special=False):
     return self.model.encode(text, allowed_special="all" if allow_special else set(), disallowed_special=set())
 
 
+
 # **** helper functions ****
-async def fetch_async(url: str, name: Optional[Union[Path, str]] = None, subdir: Optional[str] = None,
-                      allow_caching=not os.getenv("DISABLE_HTTP_CACHE")) -> Path:
-    func = partial(fetch, url, name, subdir, allow_caching)
-    return await asyncio.get_event_loop().run_in_executor(None, func)
+async def fetch_async(
+  url: str,
+  name: Optional[Union[Path, str]] = None,
+  subdir: Optional[str] = None,
+  allow_caching=not os.getenv("DISABLE_HTTP_CACHE"),
+) -> Path:
+  func = partial(fetch, url, name, subdir, allow_caching)
+  return await asyncio.get_event_loop().run_in_executor(None, func)
+
 
 
 def concat_weights(models, device=None):
   def convert(name) -> Tensor:
@@ -73,11 +100,14 @@ def concat_weights(models, device=None):
     axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
     lazy_tensors = [data.to(device=device) for data in disk_tensors]
     return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
+
   return {name: convert(name) for name in {name: None for model in models for name in model}}
 
 
-def load(fn:str):
-  if fn.endswith('.index.json'):
-    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
+
+def load(fn: str):
+  if fn.endswith(".index.json"):
+    with open(fn) as fp:
+      weight_map = json.load(fp)["weight_map"]
     parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
     return {k: parts[n][k] for k, n in weight_map.items()}
   elif fn.endswith(".safetensors"):
   elif fn.endswith(".safetensors"):
@@ -85,6 +115,7 @@ def load(fn:str):
   else:
     return torch_load(fn)
 
 
+
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=None, device=None):
   # build model
   linear = nn.Linear
@@ -93,44 +124,67 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", quantize=
 
 
   # load weights
   if model_path.is_dir():
-    if (model_path / "model.safetensors.index.json").exists(): weights = load(str(model_path / "model.safetensors.index.json"))
-    elif (model_path / "model.safetensors").exists(): weights = load(str(model_path / "model.safetensors"))
-    else: weights = concat_weights([load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
+    if (model_path / "model.safetensors.index.json").exists():
+      weights = load(str(model_path / "model.safetensors.index.json"))
+    elif (model_path / "model.safetensors").exists():
+      weights = load(str(model_path / "model.safetensors"))
+    else:
+      weights = concat_weights(
+        [load(str(model_path / f"consolidated.{i:02d}.pth")) for i in range(MODEL_PARAMS[model_size]["files"])],
+        device[0] if isinstance(device, tuple) else device,
+      )
   else:
     weights = load(str(model_path))
   if "model.embed_tokens.weight" in weights:
   if "model.embed_tokens.weight" in weights:
-    weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"], shard=shard)
+    weights = convert_from_huggingface(
+      weights,
+      model,
+      MODEL_PARAMS[model_size]["args"]["n_heads"],
+      MODEL_PARAMS[model_size]["args"]["n_kv_heads"],
+      shard=shard,
+    )
   weights = fix_bf16(weights)
 
 
   with Context(BEAM=0):
     # quantize
     if quantize is not None:
       weights = linear.quantize(weights, device)
-      for _,v in weights.items(): v.realize()
+      for _, v in weights.items():
+        v.realize()
 
 
     # shard
     if isinstance(device, tuple):
-      for k,v in nn.state.get_state_dict(model).items():
-        if 'scale' in k: v.shard_(device, axis=None)  # from quantized
-        elif '.attention.' in k: v.shard_(device, axis=-1)
-        elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
-        elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
-        elif '.feed_forward.' in k: v.shard_(device, axis=-1)
-        elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
-        elif 'output.weight' in k: v.shard_(device, axis=0)
-        else: v.shard_(device, axis=None)
+      for k, v in nn.state.get_state_dict(model).items():
+        if "scale" in k:
+          v.shard_(device, axis=None)  # from quantized
+        elif ".attention." in k:
+          v.shard_(device, axis=-1)
+        elif ".feed_forward.w1." in k:
+          v.shard_(device, axis=0)
+        elif ".feed_forward.w3." in k:
+          v.shard_(device, axis=0)
+        elif ".feed_forward." in k:
+          v.shard_(device, axis=-1)
+        elif "tok_embeddings.weight" in k:
+          v.shard_(device, axis=0)
+        elif "output.weight" in k:
+          v.shard_(device, axis=0)
+        else:
+          v.shard_(device, axis=None)
 
 
     # replace weights in model
     load_state_dict(model, weights, strict=False, consume=True)
   return model
 
 
+
 # default settings
-TEMPERATURE = 0 # 0.85
+TEMPERATURE = 0  # 0.85
 TOP_K = 25
 TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 
 
+
 def prefill(model, toks, start_pos=0):
   # prefill the model
   for tok in tqdm(toks):
@@ -139,71 +193,104 @@ def prefill(model, toks, start_pos=0):
     start_pos += 1
   return start_pos
 
 
+
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
-    def __init__(self):
-        self.shard = None
+  def __init__(self):
+    self.shard = None
+
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
+    await self.ensure_shard(shard)
+    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
 
-    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
-        await self.ensure_shard(shard)
-        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    toks = self.tokenizer.encode(prompt)
+    start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
+    last_tok = toks[-1]
 
 
-        toks = self.tokenizer.encode(prompt)
-        start_pos = prefill(self.model, toks[:-1], start_pos=start_pos)
-        last_tok = toks[-1]
+    output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
+    if output_data.size == 1:
+      start_pos += 1
 
 
-        output_data = np.array([self.model(Tensor([[last_tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-        if output_data.size == 1:
-           start_pos += 1
+    return (
+      output_data,
+      json.dumps({"start_pos": start_pos}),
+      output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
+    )
 
 
-        return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    await self.ensure_shard(shard)
+    start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
 
-    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-        await self.ensure_shard(shard)
-        start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
+    output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
+    if output_data.size == 1:
+      start_pos += 1
 
 
-        output_data: np.ndarray = np.array([self.model(Tensor([input_data]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P).tolist()])
-        if output_data.size == 1:
-           start_pos += 1
+    return (
+      output_data,
+      json.dumps({"start_pos": start_pos}),
+      output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens,
+    )
 
 
-        return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
 
 
-    async def ensure_shard(self, shard: Shard):
-        if self.shard == shard:
-            return
+    model_path = Path(shard.model_id)
+    models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
+    model_path = models_dir / shard.model_id
+    size = "8B"
+    if Path(model_path / "model.safetensors.index.json").exists():
+      model = model_path
+    else:
 
 
-        model_path = Path(shard.model_id)
-        models_dir = Path(_cache_dir) / "tinygrad" / "downloads"
-        model_path = models_dir / shard.model_id
+      if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
+      if shard.model_id.lower().find("llama3-8b-sfr") != -1:
+        await fetch_async(
+          "https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model",
+          "tokenizer.model",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors",
+          "model-00001-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors",
+          "model-00002-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors",
+          "model-00003-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors",
+          "model-00004-of-00004.safetensors",
+          subdir=shard.model_id,
+        )
+        model = await fetch_async(
+          "https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json",
+          "model.safetensors.index.json",
+          subdir=shard.model_id,
+        )
         size = "8B"
         size = "8B"
-        if Path(model_path / "model.safetensors.index.json").exists():
-            model = model_path
-        else:
+      elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
+        raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
+        # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
+        # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
+        # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
+        # size = "70B"
+      else:
+        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
+
+    model = build_transformer(model_path, shard=shard, model_size=size)
+    tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
 
 
-            if DEBUG >= 2: print(f"Downloading tinygrad model {shard.model_id}...")
-            if shard.model_id.lower().find("llama3-8b-sfr") != -1:
-                await fetch_async("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-                await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-                model = await fetch_async("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-                size = "8B"
-            elif shard.model_id.lower().find("llama3-70b-sfr") != -1:
-                raise NotImplementedError("llama3-70b-sfr is not implemented for tinygrad")
-                # fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-70B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00001-of-00004.safetensors", "model-00001-of-00004.safetensors", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00002-of-00004.safetensors", "model-00002-of-00004.safetensors", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00003-of-00004.safetensors", "model-00003-of-00004.safetensors", subdir=shard.model_id)
-                # fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir=shard.model_id)
-                # model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir=shard.model_id)
-                # size = "70B"
-            else:
-                raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {shard.model_id}")
-
-        model = build_transformer(model_path, shard=shard, model_size=size)
-        tokenizer = Tokenizer(str((model_path if model_path.is_dir() else model_path.parent) / "tokenizer.model"))
-
-        self.shard = shard
-        self.model = model
-        self.tokenizer = tokenizer
+    self.shard = shard
+    self.model = model
+    self.tokenizer = tokenizer
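
For orientation, a hedged sketch of how this engine's API is exercised end to end (the real orchestration lives elsewhere in exo; the method names and the JSON inference_state carrying "start_pos" come from the diff above, the rest is illustrative):

import asyncio
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine

async def demo():
  engine = TinygradDynamicShardInferenceEngine()
  # "llama3-8b-sfr" matches the download branch in ensure_shard above.
  shard = Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=31, n_layers=32)
  # First call tokenizes and prefills the prompt; returns (output, state_json, is_finished).
  output, state, done = await engine.infer_prompt("req-1", shard, "hello")
  while not done:
    # Later calls feed the previous output back in; state is a JSON string holding
    # {"start_pos": ...} so the KV-cache position survives across calls.
    output, state, done = await engine.infer_tensor("req-1", shard, output, inference_state=state)

# asyncio.run(demo())  # downloads the model weights on first use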

+ 100 - 36
exo/inference/tinygrad/models/llama.py

@@ -3,21 +3,24 @@ from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from exo.inference.shard import Shard
 
 
+
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
-  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
+  freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
   freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
-  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim//2, 2)
+  return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
+
 
 
 # (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
 def complex_mult(A, c, d):
-  a,b = A[..., 0:1], A[..., 1:2]
-  ro = a*c - b*d
-  co = a*d + b*c
+  a, b = A[..., 0:1], A[..., 1:2]
+  ro = a * c - b * d
+  co = a * d + b * c
   return ro.cat(co, dim=-1)
 
 
-def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Tensor]:
+
+def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
   assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
   xq = xq.reshape(*xq.shape[0:-1], -1, 2)
   xk = xk.reshape(*xk.shape[0:-1], -1, 2)
@@ -27,16 +30,19 @@ def apply_rotary_emb(xq:Tensor, xk:Tensor, freqs_cis:Tensor) -> Tuple[Tensor, Te
   xk_out = complex_mult(xk, c, d)
   return xq_out.flatten(3), xk_out.flatten(3)
 
 
-def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   bs, seqlen, n_kv_heads, head_dim = x.shape
-  if n_rep == 1: return x
+  if n_rep == 1:
+    return x
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
 
 
+
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
-    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
+    self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads  # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
     self.head_dim = dim // n_heads
     self.n_rep = self.n_heads // self.n_kv_heads
     self.max_context = max_context
@@ -46,7 +52,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
     self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
     xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
     xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
     xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@@ -64,10 +70,10 @@ class Attention:
 
 
     # update the cache
     assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
+    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
+    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -75,26 +81,39 @@ class Attention:
     attn = attn.reshape(bsz, seqlen, -1)
     return self.wo(attn)
 
 
+
 class FeedForward:
-  def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
+  def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
     self.w1 = linear(dim, hidden_dim, bias=False)
     self.w2 = linear(hidden_dim, dim, bias=False)
-    self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
+    self.w3 = linear(dim, hidden_dim, bias=False)  # the gate in Gated Linear Unit
+
+  def __call__(self, x: Tensor) -> Tensor:
+    return self.w2(self.w1(x).silu() * self.w3(x))  # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
-  def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
 
 
 class TransformerBlock:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear, feed_forward=FeedForward):
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_kv_heads: int,
+    norm_eps: float,
+    max_context: int,
+    linear=nn.Linear,
+    feed_forward=FeedForward,
+  ):
     self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
     self.feed_forward = feed_forward(dim, hidden_dim, linear)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
     h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
+
 # standard openai sampling
 def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
@@ -102,7 +121,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
 
 
   # if temperature is very low just use argmax
-  if temp < 1e-6: return logits.argmax()
+  if temp < 1e-6:
+    return logits.argmax()
 
 
   # alpha sampling
   if af or ap:
@@ -116,10 +136,16 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # softmax
   t = (logits / temp).softmax()
 
 
-  counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
+  counter, counter2 = (
+    Tensor.arange(t.numel(), device=logits.device).contiguous(),
+    Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous(),
+  )
   # top k
   if k:
-    output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
+    output, output_indices = (
+      Tensor.zeros(k, device=logits.device).contiguous(),
+      Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous(),
+    )
     for i in range(k):
       t_argmax = (t.numel() - ((t == (t_max := t.max())) * counter2).max() - 1).cast(dtypes.default_int)
       output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
@@ -144,8 +170,24 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
 
 
   return output_token
 
 
+
 class Transformer:
-  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, shard: Shard, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True, feed_forward=FeedForward):
+  def __init__(
+    self,
+    dim: int,
+    hidden_dim: int,
+    n_heads: int,
+    n_layers: int,
+    norm_eps: float,
+    vocab_size,
+    shard: Shard,
+    linear=nn.Linear,
+    n_kv_heads=None,
+    rope_theta=10000,
+    max_context=1024,
+    jit=True,
+    feed_forward=FeedForward,
+  ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(shard.end_layer - shard.start_layer + 1)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
@@ -155,13 +197,22 @@ class Transformer:
     self.forward_jit = TinyJit(self.forward) if jit else None
     self.shard = shard
 
 
-  def forward(self, h:Tensor, start_pos:Union[Variable,int], temperature:float, top_k:int, top_p:float, alpha_f:float, alpha_p:float):
+  def forward(
+    self,
+    h: Tensor,
+    start_pos: Union[Variable, int],
+    temperature: float,
+    top_k: int,
+    top_p: float,
+    alpha_f: float,
+    alpha_p: float,
+  ):
     seqlen = h.shape[1]
-    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
+    freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
 
 
     if self.shard.is_first_layer():
       h = self.tok_embeddings(h)
-    mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos+1).realize() if seqlen > 1 else None
+    mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=h.dtype, device=h.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
 
     for i, layer in enumerate(self.layers):
       h = layer(h, start_pos, freqs_cis, mask)
@@ -169,12 +220,21 @@ class Transformer:
       #   print(f"layer {i}: {str(h.numpy())[:60]}")
       #   print(f"layer {i}: {str(h.numpy())[:60]}")
 
 
     if self.shard.is_last_layer():
     if self.shard.is_last_layer():
-        logits = self.output(self.norm(h)).float()[:, -1, :]
-        return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float()[:, -1, :]
+      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
     else:
       return h.realize()
 
 
-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0, top_k:int=0, top_p:float=0.8, alpha_f:float=0.0, alpha_p:float=0.0):
+  def __call__(
+    self,
+    tokens: Tensor,
+    start_pos: Variable,
+    temperature: float = 0.0,
+    top_k: int = 0,
+    top_p: float = 0.8,
+    alpha_f: float = 0.0,
+    alpha_p: float = 0.0,
+  ):
     # TODO: better way to handle the first call v.s. the rest?
     # if tokens.shape[0:2] == (1,1) and self.forward_jit is not None:
     #   return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
@@ -185,9 +245,11 @@ class Transformer:
       if hasattr(layer.attention, "cache_kv"):
       if hasattr(layer.attention, "cache_kv"):
         layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
         layer.attention.cache_kv = layer.attention.cache_kv.zeros_like()
 
 
+
 # *** helpers ***
 
 
-def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
+
+def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int, shard: Shard):
   def permute(v: Tensor, n_heads: int):
     return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
 
 
@@ -202,12 +264,13 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
   }
   sd = {}
   for k, v in weights.items():
-    if ".rotary_emb." in k: continue
+    if ".rotary_emb." in k:
+      continue
     v = v.to(Device.DEFAULT)
     if "model.layers" in k:
-      layer_num = int(k.split('.')[2])
+      layer_num = int(k.split(".")[2])
       if shard.start_layer <= layer_num <= shard.end_layer:
-          k = f"model.layers.{layer_num - shard.start_layer}." + '.'.join(k.split('.')[3:])
+        k = f"model.layers.{layer_num - shard.start_layer}." + ".".join(k.split(".")[3:])
       else:
         continue
 
 
@@ -218,9 +281,10 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
     sd[keymap[k]] = v
   return sd
 
 
-def fix_bf16(weights:Dict[Any, Tensor]):
+
+def fix_bf16(weights: Dict[Any, Tensor]):
   if getenv("SUPPORT_BF16", 1):
   if getenv("SUPPORT_BF16", 1):
     # TODO: without casting to float16, 70B llama OOM on tinybox.
     # TODO: without casting to float16, 70B llama OOM on tinybox.
-    return {k:v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+    return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
   # TODO: check if device supports bf16
-  return {k:v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k,v in weights.items()}
+  return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
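
A small illustrative sketch of the shard-aware renumbering that convert_from_huggingface performs above: checkpoint keys for layers outside [shard.start_layer, shard.end_layer] are dropped, and the remaining layers are shifted so the shard's first layer becomes local layer 0. The helper name and the concrete layer numbers below are made up for illustration:

from typing import Optional

def remap(key: str, start_layer: int, end_layer: int) -> Optional[str]:
  parts = key.split(".")
  if key.startswith("model.layers."):
    layer_num = int(parts[2])
    if not (start_layer <= layer_num <= end_layer):
      return None  # this weight belongs to another node's shard
    return f"model.layers.{layer_num - start_layer}." + ".".join(parts[3:])
  return key  # non-layer weights (embeddings, norm, lm head) pass through unchanged

assert remap("model.layers.12.self_attn.q_proj.weight", 10, 19) == "model.layers.2.self_attn.q_proj.weight"
assert remap("model.layers.3.mlp.up_proj.weight", 10, 19) is None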

+ 1 - 1
exo/networking/__init__.py

@@ -2,4 +2,4 @@ from .discovery import Discovery
 from .peer_handle import PeerHandle
 from .server import Server
 
 
-__all__ = ['Discovery', 'PeerHandle', 'Server']
+__all__ = ["Discovery", "PeerHandle", "Server"]

+ 10 - 9
exo/networking/discovery.py

@@ -2,15 +2,16 @@ from abc import ABC, abstractmethod
 from typing import List
 from .peer_handle import PeerHandle
 
 
+
 class Discovery(ABC):
-    @abstractmethod
-    async def start(self) -> None:
-        pass
+  @abstractmethod
+  async def start(self) -> None:
+    pass
 
 
-    @abstractmethod
-    async def stop(self) -> None:
-        pass
+  @abstractmethod
+  async def stop(self) -> None:
+    pass
 
 
-    @abstractmethod
-    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-        pass
+  @abstractmethod
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    pass
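
Since Discovery is an abstract base class, a concrete backend only has to supply these three coroutines. A toy in-memory sketch (hypothetical, for illustration only; GRPCDiscovery below is the real implementation in this commit):

from typing import List
from exo.networking import Discovery, PeerHandle

class StaticDiscovery(Discovery):
  """Returns a fixed list of peers instead of discovering them on the network."""
  def __init__(self, peers: List[PeerHandle]):
    self.peers = peers

  async def start(self) -> None:
    pass  # nothing to set up

  async def stop(self) -> None:
    pass  # nothing to tear down

  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
    return self.peers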

+ 172 - 120
exo/networking/grpc/grpc_discovery.py

@@ -9,130 +9,182 @@ from .grpc_peer_handle import GRPCPeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo import DEBUG_DISCOVERY
 
 
+
 class ListenProtocol(asyncio.DatagramProtocol):
-    def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
-        super().__init__()
-        self.on_message = on_message
-        self.loop = asyncio.get_event_loop()
+  def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
+    super().__init__()
+    self.on_message = on_message
+    self.loop = asyncio.get_event_loop()
 
 
-    def connection_made(self, transport):
-        self.transport = transport
+  def connection_made(self, transport):
+    self.transport = transport
 
 
-    def datagram_received(self, data, addr):
-        asyncio.create_task(self.on_message(data, addr))
+  def datagram_received(self, data, addr):
+    asyncio.create_task(self.on_message(data, addr))
 
 
 
 
 class GRPCDiscovery(Discovery):
-    def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, discovery_timeout: int = 30):
-        self.node_id = node_id
-        self.node_port = node_port
-        self.device_capabilities = device_capabilities
-        self.listen_port = listen_port
-        self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
-        self.broadcast_interval = broadcast_interval
-        self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
-        self.broadcast_task = None
-        self.listen_task = None
-        self.cleanup_task = None
-        self.discovery_timeout = discovery_timeout
-
-    async def start(self):
-        self.device_capabilities = device_capabilities()
-        self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
-        self.listen_task = asyncio.create_task(self.task_listen_for_peers())
-        self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
-
-    async def stop(self):
-        if self.broadcast_task:
-            self.broadcast_task.cancel()
-        if self.listen_task:
-            self.listen_task.cancel()
-        if self.cleanup_task:
-            self.cleanup_task.cancel()
-        if self.broadcast_task or self.listen_task or self.cleanup_task:
-            await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
-
-    async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
-        if DEBUG_DISCOVERY >= 2: print("Starting peer discovery process...")
-
+  def __init__(
+    self,
+    node_id: str,
+    node_port: int,
+    listen_port: int,
+    broadcast_port: int = None,
+    broadcast_interval: int = 1,
+    device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
+    discovery_timeout: int = 30,
+  ):
+    self.node_id = node_id
+    self.node_port = node_port
+    self.device_capabilities = device_capabilities
+    self.listen_port = listen_port
+    self.broadcast_port = broadcast_port if broadcast_port is not None else listen_port
+    self.broadcast_interval = broadcast_interval
+    self.known_peers: Dict[str, Tuple[GRPCPeerHandle, float, float]] = {}
+    self.broadcast_task = None
+    self.listen_task = None
+    self.cleanup_task = None
+    self.discovery_timeout = discovery_timeout
+
+  async def start(self):
+    self.device_capabilities = device_capabilities()
+    self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
+    self.listen_task = asyncio.create_task(self.task_listen_for_peers())
+    self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
+
+  async def stop(self):
+    if self.broadcast_task:
+      self.broadcast_task.cancel()
+    if self.listen_task:
+      self.listen_task.cancel()
+    if self.cleanup_task:
+      self.cleanup_task.cancel()
+    if self.broadcast_task or self.listen_task or self.cleanup_task:
+      await asyncio.gather(self.broadcast_task, self.listen_task, self.cleanup_task, return_exceptions=True)
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if DEBUG_DISCOVERY >= 2:
+      print("Starting peer discovery process...")
+
+    if wait_for_peers > 0:
+      while len(self.known_peers) == 0:
+        if DEBUG_DISCOVERY >= 2:
+          print("No peers discovered yet, retrying in 1 second...")
+        await asyncio.sleep(1)  # Keep trying to find peers
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
+
+    grace_period = 5  # seconds
+    while True:
+      initial_peer_count = len(self.known_peers)
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
+      if len(self.known_peers) == initial_peer_count:
         if wait_for_peers > 0:
-            while len(self.known_peers) == 0:
-                if DEBUG_DISCOVERY >= 2: print("No peers discovered yet, retrying in 1 second...")
-                await asyncio.sleep(1)  # Keep trying to find peers
-            if DEBUG_DISCOVERY >= 2: print(f"Discovered first peer: {next(iter(self.known_peers.values()))}")
-
-        grace_period = 5  # seconds
-        while True:
-            initial_peer_count = len(self.known_peers)
-            if DEBUG_DISCOVERY >= 2: print(f"Current number of known peers: {initial_peer_count}. Waiting {grace_period} seconds to discover more...")
-            if len(self.known_peers) == initial_peer_count:
-                if wait_for_peers > 0:
-                    await asyncio.sleep(grace_period)
-                    if DEBUG_DISCOVERY >= 2: print(f"Waiting additional {wait_for_peers} seconds for more peers.")
-                    wait_for_peers = 0
-                else:
-                    if DEBUG_DISCOVERY >= 2: print("No new peers discovered in the last grace period. Ending discovery process.")
-                    break  # No new peers found in the grace period, we are done
-
-        return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
-
-    async def task_broadcast_presence(self):
-        transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-                    lambda: asyncio.DatagramProtocol(),
-                    local_addr=('0.0.0.0', 0),
-                    family=socket.AF_INET)
-        sock = transport.get_extra_info('socket')
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
-
-        message = json.dumps({
-            "type": "discovery",
-            "node_id": self.node_id,
-            "grpc_port": self.node_port,
-            "device_capabilities": self.device_capabilities.to_dict()
-        }).encode('utf-8')
-
-        while True:
-            try:
-                if DEBUG_DISCOVERY >= 3: print(f"Broadcast presence: {message}")
-                transport.sendto(message, ('<broadcast>', self.broadcast_port))
-                await asyncio.sleep(self.broadcast_interval)
-            except Exception as e:
-                print(f"Error in broadcast presence: {e}")
-                import traceback
-                print(traceback.format_exc())
-
-    async def on_listen_message(self, data, addr):
-        message = json.loads(data.decode('utf-8'))
-        if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}")
-        if message['type'] == 'discovery' and message['node_id'] != self.node_id:
-            peer_id = message['node_id']
-            peer_host = addr[0]
-            peer_port = message['grpc_port']
-            device_capabilities = DeviceCapabilities(**message['device_capabilities'])
-            if peer_id not in self.known_peers:
-                self.known_peers[peer_id] = (GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities), time.time(), time.time())
-                if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
-            self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
-
-    async def task_listen_for_peers(self):
-        await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port))
-        if DEBUG_DISCOVERY >= 2: print("Started listen task")
-
-    async def task_cleanup_peers(self):
-        while True:
-            try:
-                current_time = time.time()
-                peers_to_remove = [
-                    peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values() if
-                    (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
-                ]
-                if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
-                if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0: print(f"Cleaning up peers: {peers_to_remove}")
-                for peer_id in peers_to_remove:
-                    if peer_id in self.known_peers: del self.known_peers[peer_id]
-                    if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
-                await asyncio.sleep(self.broadcast_interval)
-            except Exception as e:
-                print(f"Error in cleanup peers: {e}")
-                import traceback
-                print(traceback.format_exc())
+          await asyncio.sleep(grace_period)
+          if DEBUG_DISCOVERY >= 2:
+            print(f"Waiting additional {wait_for_peers} seconds for more peers.")
+          wait_for_peers = 0
+        else:
+          if DEBUG_DISCOVERY >= 2:
+            print("No new peers discovered in the last grace period. Ending discovery process.")
+          break  # No new peers found in the grace period, we are done
+
+    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+
+  async def task_broadcast_presence(self):
+    transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: asyncio.DatagramProtocol(), local_addr=("0.0.0.0", 0), family=socket.AF_INET)
+    sock = transport.get_extra_info("socket")
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
+
+    message = json.dumps(
+      {
+        "type": "discovery",
+        "node_id": self.node_id,
+        "grpc_port": self.node_port,
+        "device_capabilities": self.device_capabilities.to_dict(),
+      }
+    ).encode("utf-8")
+
+    while True:
+      try:
+        if DEBUG_DISCOVERY >= 3:
+          print(f"Broadcast presence: {message}")
+        transport.sendto(message, ("<broadcast>", self.broadcast_port))
+        await asyncio.sleep(self.broadcast_interval)
+      except Exception as e:
+        print(f"Error in broadcast presence: {e}")
+        import traceback
+
+        print(traceback.format_exc())
+
+  async def on_listen_message(self, data, addr):
+    if not data:
+      return
+
+    decoded_data = data.decode("utf-8", errors="ignore")
+
+    # Check if the decoded data starts with a valid JSON character
+    if not (decoded_data.strip() and decoded_data.strip()[0] in "{["):
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Received invalid JSON data from {addr}: {decoded_data[:100]}")
+      return
+
+    try:
+      decoder = json.JSONDecoder(strict=False)
+      message = decoder.decode(decoded_data)
+    except json.JSONDecodeError as e:
+      if DEBUG_DISCOVERY >= 2:
+        print(f"Error decoding JSON data from {addr}: {e}")
+      return
+
+    if DEBUG_DISCOVERY >= 2:
+      print(f"received from peer {addr}: {message}")
+
+    if message["type"] == "discovery" and message["node_id"] != self.node_id:
+      peer_id = message["node_id"]
+      peer_host = addr[0]
+      peer_port = message["grpc_port"]
+      device_capabilities = DeviceCapabilities(**message["device_capabilities"])
+      if peer_id not in self.known_peers:
+        self.known_peers[peer_id] = (
+          GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities),
+          time.time(),
+          time.time(),
+        )
+        if DEBUG_DISCOVERY >= 2:
+          print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}")
+      self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+
+  async def task_listen_for_peers(self):
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
+    if DEBUG_DISCOVERY >= 2:
+      print("Started listen task")
+
+  async def task_cleanup_peers(self):
+    while True:
+      try:
+        current_time = time.time()
+        peers_to_remove = [
+          peer_handle.id()
+          for peer_handle, connected_at, last_seen in self.known_peers.values()
+          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout
+        ]
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:",
+            {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}" for peer_handle, connected_at, last_seen in self.known_peers.values()},
+          )
+        if DEBUG_DISCOVERY >= 2 and len(peers_to_remove) > 0:
+          print(f"Cleaning up peers: {peers_to_remove}")
+        for peer_id in peers_to_remove:
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+          if DEBUG_DISCOVERY >= 2:
+            print(f"Removed peer {peer_id} due to inactivity.")
+        await asyncio.sleep(self.broadcast_interval)
+      except Exception as e:
+        print(f"Error in cleanup peers: {e}")
+        import traceback
+
+        print(traceback.format_exc())
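
The discovery protocol above is just a JSON payload broadcast over UDP. A hedged sketch of one announcement on the wire: the field names come from task_broadcast_presence, while the concrete values and the exact layout of the serialized DeviceCapabilities (its flops sub-dict in particular) are assumptions:

import json

announcement = {
  "type": "discovery",
  "node_id": "node-a",   # sender's id; receivers skip broadcasts carrying their own node_id
  "grpc_port": 8080,     # port of the sender's gRPC NodeService
  # Assumed shape of DeviceCapabilities.to_dict(); treat as illustrative only.
  "device_capabilities": {"model": "MacBook Pro", "chip": "Apple M3", "memory": 36864, "flops": {"fp32": 0.0, "fp16": 0.0, "int8": 0.0}},
}
packet = json.dumps(announcement).encode("utf-8")  # sent to ("<broadcast>", broadcast_port)
message = json.loads(packet.decode("utf-8"))       # what on_listen_message decodes on the receiving side
assert message["type"] == "discovery"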

+ 94 - 81
exo/networking/grpc/grpc_peer_handle.py

@@ -11,85 +11,98 @@ from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities
 
 
+
 class GRPCPeerHandle(PeerHandle):
-    def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
-        self._id = id
-        self.address = address
-        self._device_capabilities = device_capabilities
-        self.channel = None
-        self.stub = None
-
-    def id(self) -> str:
-        return self._id
-
-    def device_capabilities(self) -> DeviceCapabilities:
-        return self._device_capabilities
-
-    async def connect(self):
-        self.channel = grpc.aio.insecure_channel(self.address, options=[
-            ('grpc.max_metadata_size', 32*1024*1024)
-        ])
-        self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
-
-    async def is_connected(self) -> bool:
-        return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
-
-    async def disconnect(self):
-        if self.channel:
-            await self.channel.close()
-        self.channel = None
-        self.stub = None
-
-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        request = node_service_pb2.PromptRequest(prompt=prompt, shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers), request_id=request_id, inference_state=inference_state)
-        response = await self.stub.SendPrompt(request)
-
-        if not response.tensor_data or not response.shape or not response.dtype:
-            return None
-
-        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-    async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        request = node_service_pb2.TensorRequest(
-            shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers),
-            tensor = node_service_pb2.Tensor(
-                tensor_data=tensor.tobytes(),
-                shape=tensor.shape,
-                dtype=str(tensor.dtype)
-            ),
-            request_id=request_id,
-            inference_state=inference_state
-        )
-        response = await self.stub.SendTensor(request)
-
-        if not response.tensor_data or not response.shape or not response.dtype:
-            return None
-
-        return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
-        response = await self.stub.GetInferenceResult(request)
-        if response.tensor is None:
-            return None, response.is_finished
-        return np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), response.is_finished
-
-    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-        request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
-        response = await self.stub.CollectTopology(request)
-        topology = Topology()
-        for node_id, capabilities in response.nodes.items():
-            device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
-            topology.update_node(node_id, device_capabilities)
-        for node_id, peers in response.peer_graph.items():
-            for peer_id in peers.peer_ids:
-                topology.add_edge(node_id, peer_id)
-        return topology
-
-    async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-        request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
-        await self.stub.SendResult(request)
-
-    async def send_opaque_status(self, request_id: str, status: str) -> None:
-        request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
-        await self.stub.SendOpaqueStatus(request)
+  def __init__(self, _id: str, address: str, device_capabilities: DeviceCapabilities):
+    self._id = _id
+    self.address = address
+    self._device_capabilities = device_capabilities
+    self.channel = None
+    self.stub = None
+
+  def id(self) -> str:
+    return self._id
+
+  def device_capabilities(self) -> DeviceCapabilities:
+    return self._device_capabilities
+
+  async def connect(self):
+    self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32 * 1024 * 1024)])
+    self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
+
+  async def is_connected(self) -> bool:
+    return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
+
+  async def disconnect(self):
+    if self.channel:
+      await self.channel.close()
+    self.channel = None
+    self.stub = None
+
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+    request = node_service_pb2.PromptRequest(
+      prompt=prompt,
+      shard=node_service_pb2.Shard(
+        model_id=shard.model_id,
+        start_layer=shard.start_layer,
+        end_layer=shard.end_layer,
+        n_layers=shard.n_layers,
+      ),
+      request_id=request_id,
+      inference_state=inference_state,
+    )
+    response = await self.stub.SendPrompt(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+    request = node_service_pb2.TensorRequest(
+      shard=node_service_pb2.Shard(
+        model_id=shard.model_id,
+        start_layer=shard.start_layer,
+        end_layer=shard.end_layer,
+        n_layers=shard.n_layers,
+      ),
+      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
+      request_id=request_id,
+      inference_state=inference_state,
+    )
+    response = await self.stub.SendTensor(request)
+
+    if not response.tensor_data or not response.shape or not response.dtype:
+      return None
+
+    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
+    response = await self.stub.GetInferenceResult(request)
+    if response.tensor is None:
+      return None, response.is_finished
+    return (
+      np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
+      response.is_finished,
+    )
+
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
+    response = await self.stub.CollectTopology(request)
+    topology = Topology()
+    for node_id, capabilities in response.nodes.items():
+      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
+      topology.update_node(node_id, device_capabilities)
+    for node_id, peers in response.peer_graph.items():
+      for peer_id in peers.peer_ids:
+        topology.add_edge(node_id, peer_id)
+    return topology
+
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
+    await self.stub.SendResult(request)
+
+  async def send_opaque_status(self, request_id: str, status: str) -> None:
+    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
+    await self.stub.SendOpaqueStatus(request)

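Roughly how the reformatted GRPCPeerHandle above gets used: connect, send a prompt for a shard, and read back a numpy array (or None when the peer returns an empty tensor). A hedged sketch only; the address, model id and layer split are placeholders, and a peer must actually be listening at that address for the call to succeed.

    import asyncio

    from exo.inference.shard import Shard
    from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle


    async def demo():
      # address/port, model id and layer ranges are illustrative assumptions
      peer = GRPCPeerHandle("peer1", "192.168.1.10:50051", device_capabilities=None)
      await peer.connect()
      shard = Shard(model_id="some-model", start_layer=0, end_layer=15, n_layers=32)
      tokens = await peer.send_prompt(shard, "hello", request_id="req-1")
      print(tokens)  # np.ndarray reconstructed from the response, or None
      await peer.disconnect()


    asyncio.run(demo())
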
+ 94 - 65
exo/networking/grpc/grpc_server.py

@@ -8,78 +8,107 @@ from exo import DEBUG
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 
+
 class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
-    def __init__(self, node: Node, host: str, port: int):
-        self.node = node
-        self.host = host
-        self.port = port
-        self.server = None
+  def __init__(self, node: Node, host: str, port: int):
+    self.node = node
+    self.host = host
+    self.port = port
+    self.server = None
 
 
-    async def start(self) -> None:
-        self.server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10), options=[
-            ('grpc.max_metadata_size', 32*1024*1024),
-            ('grpc.max_send_message_length', 128 * 1024 * 1024),
-            ('grpc.max_receive_message_length', 128 * 1024 * 1024),
-        ])
-        node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
-        listen_addr = f'{self.host}:{self.port}'
-        self.server.add_insecure_port(listen_addr)
-        await self.server.start()
-        if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
+  async def start(self) -> None:
+    self.server = grpc.aio.server(
+      futures.ThreadPoolExecutor(max_workers=10),
+      options=[
+        ("grpc.max_metadata_size", 32 * 1024 * 1024),
+        ("grpc.max_send_message_length", 128 * 1024 * 1024),
+        ("grpc.max_receive_message_length", 128 * 1024 * 1024),
+      ],
+    )
+    node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
+    listen_addr = f"{self.host}:{self.port}"
+    self.server.add_insecure_port(listen_addr)
+    await self.server.start()
+    if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
 
 
-    async def stop(self) -> None:
-        if self.server:
-            await self.server.stop(grace=5)
-            await self.server.wait_for_termination()
-            if DEBUG >= 1: print("Server stopped and all connections are closed")
+  async def stop(self) -> None:
+    if self.server:
+      await self.server.stop(grace=5)
+      await self.server.wait_for_termination()
+      if DEBUG >= 1: print("Server stopped and all connections are closed")
 
 
-    async def SendPrompt(self, request, context):
-        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        prompt = request.prompt
-        request_id = request.request_id
-        result = await self.node.process_prompt(shard, prompt, request_id)
-        if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
-        tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+  async def SendPrompt(self, request, context):
+    shard = Shard(
+      model_id=request.shard.model_id,
+      start_layer=request.shard.start_layer,
+      end_layer=request.shard.end_layer,
+      n_layers=request.shard.n_layers,
+    )
+    prompt = request.prompt
+    request_id = request.request_id
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 2: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 
-    async def SendTensor(self, request, context):
-        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
-        request_id = request.request_id
-        inference_state = request.inference_state
+  async def SendTensor(self, request, context):
+    shard = Shard(
+      model_id=request.shard.model_id,
+      start_layer=request.shard.start_layer,
+      end_layer=request.shard.end_layer,
+      n_layers=request.shard.n_layers,
+    )
+    tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
+    request_id = request.request_id
+    inference_state = request.inference_state
 
 
-        result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
-        if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
-        tensor_data = result.tobytes() if result is not None else None
-        return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    if DEBUG >= 2: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
+    tensor_data = result.tobytes() if result is not None else None
+    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 
-    async def GetInferenceResult(self, request, context):
-        request_id = request.request_id
-        result = await self.node.get_inference_result(request_id)
-        if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
-        tensor_data = result[0].tobytes() if result[0] is not None else None
-        return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), is_finished=result[1]) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
+  async def GetInferenceResult(self, request, context):
+    request_id = request.request_id
+    result = await self.node.get_inference_result(request_id)
+    if DEBUG >= 2: print(f"GetInferenceResult {request_id=}: {result}")
+    tensor_data = result[0].tobytes() if result[0] is not None else None
+    return (
+      node_service_pb2.InferenceResult(
+        tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)),
+        is_finished=result[1],
+      )
+      if result[0] is not None
+      else node_service_pb2.InferenceResult(is_finished=result[1])
+    )
 
 
-    async def CollectTopology(self, request, context):
-        max_depth = request.max_depth
-        visited = set(request.visited)
-        topology = await self.node.collect_topology(visited, max_depth)
-        nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory, flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8)) for node_id, cap in topology.nodes.items()}
-        peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
-        if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
-        return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
+  async def CollectTopology(self, request, context):
+    max_depth = request.max_depth
+    visited = set(request.visited)
+    topology = await self.node.collect_topology(visited, max_depth)
+    nodes = {
+      node_id: node_service_pb2.DeviceCapabilities(
+        model=cap.model,
+        chip=cap.chip,
+        memory=cap.memory,
+        flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
+      )
+      for node_id, cap in topology.nodes.items()
+    }
+    peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()}
+    if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
+    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
 
-    async def SendResult(self, request, context):
-        request_id = request.request_id
-        result = request.result
-        is_finished = request.is_finished
-        if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
-        self.node.on_token.trigger_all(request_id, result, is_finished)
-        return node_service_pb2.Empty()
+  async def SendResult(self, request, context):
+    request_id = request.request_id
+    result = request.result
+    is_finished = request.is_finished
+    if DEBUG >= 2: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
+    self.node.on_token.trigger_all(request_id, result, is_finished)
+    return node_service_pb2.Empty()
 
 
-    async def SendOpaqueStatus(self, request, context):
-        request_id = request.request_id
-        status = request.status
-        if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
-        self.node.on_opaque_status.trigger_all(request_id, status)
-        return node_service_pb2.Empty()
+  async def SendOpaqueStatus(self, request, context):
+    request_id = request.request_id
+    status = request.status
+    if DEBUG >= 2: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
+    self.node.on_opaque_status.trigger_all(request_id, status)
+    return node_service_pb2.Empty()

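And the server side of the same RPC surface: wrap a Node implementation, start the gRPC server, and stop it (with its 5-second grace period) on shutdown. A hedged sketch; the host/port and the lifetime management are placeholders rather than the repo's actual entrypoint.

    import asyncio

    from exo.networking.grpc.grpc_server import GRPCServer


    async def serve(node):  # `node` is any exo.orchestration.Node implementation (placeholder)
      server = GRPCServer(node, "0.0.0.0", 50051)  # port chosen for illustration
      await server.start()
      try:
        await asyncio.Event().wait()  # block until the task is cancelled
      finally:
        await server.stop()  # stop(grace=5) then wait_for_termination, per the class above
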
+ 14 - 13
exo/networking/grpc/test_grpc_discovery.py

@@ -2,20 +2,21 @@ import asyncio
 import unittest
 from .grpc_discovery import GRPCDiscovery
 
+
 class TestGRPCDiscovery(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
-        self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
-        await self.node1.start()
-        await self.node2.start()
+  async def asyncSetUp(self):
+    self.node1 = GRPCDiscovery("node1", 50051, 5678, 5679)
+    self.node2 = GRPCDiscovery("node2", 50052, 5679, 5678)
+    await self.node1.start()
+    await self.node2.start()
 
 
-    async def asyncTearDown(self):
-        await self.node1.stop()
-        await self.node2.stop()
+  async def asyncTearDown(self):
+    await self.node1.stop()
+    await self.node2.stop()
 
 
-    async def test_discovery(self):
-        await asyncio.sleep(4)
+  async def test_discovery(self):
+    await asyncio.sleep(4)
 
 
-        # Check discovered peers
-        print("Node1 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
-        print("Node2 Peers:", ', '.join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))
+    # Check discovered peers
+    print("Node1 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node1.known_peers.items()]))
+    print("Node2 Peers:", ", ".join([f"{peer_id}: {peer}" for peer_id, peer in self.node2.known_peers.items()]))

+ 31 - 30
exo/networking/peer_handle.py

@@ -5,43 +5,44 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 
+
 class PeerHandle(ABC):
-    @abstractmethod
-    def id(self) -> str:
-        pass
+  @abstractmethod
+  def id(self) -> str:
+    pass
 
 
-    @abstractmethod
-    def device_capabilities(self) -> DeviceCapabilities:
-        pass
+  @abstractmethod
+  def device_capabilities(self) -> DeviceCapabilities:
+    pass
 
 
-    @abstractmethod
-    async def connect(self) -> None:
-        pass
+  @abstractmethod
+  async def connect(self) -> None:
+    pass
 
 
-    @abstractmethod
-    async def is_connected(self) -> bool:
-        pass
+  @abstractmethod
+  async def is_connected(self) -> bool:
+    pass
 
 
-    @abstractmethod
-    async def disconnect(self) -> None:
-        pass
+  @abstractmethod
+  async def disconnect(self) -> None:
+    pass
 
 
-    @abstractmethod
-    async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        pass
+  @abstractmethod
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+    pass
 
 
-    @abstractmethod
-    async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
-        pass
+  @abstractmethod
+  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+    pass
 
 
-    @abstractmethod
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        pass
+  @abstractmethod
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    pass
 
 
-    @abstractmethod
-    async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-        pass
+  @abstractmethod
+  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
+    pass
 
 
-    @abstractmethod
-    async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-        pass
+  @abstractmethod
+  async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    pass

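PeerHandle is abstract, so tests or experiments need a concrete subclass. The following is a minimal in-memory stub (not part of this commit) showing the required surface; send_opaque_status is included on the assumption that the full ABC declares it, mirroring GRPCPeerHandle above.

    from typing import List, Optional, Tuple

    import numpy as np

    from exo.inference.shard import Shard
    from exo.networking.peer_handle import PeerHandle
    from exo.topology.device_capabilities import DeviceCapabilities
    from exo.topology.topology import Topology


    class InMemoryPeerHandle(PeerHandle):
      """Loopback stub: tracks connection state only, every RPC returns empty data (illustrative)."""

      def __init__(self, _id: str, caps: DeviceCapabilities):
        self._id, self._caps, self._connected = _id, caps, False

      def id(self) -> str: return self._id
      def device_capabilities(self) -> DeviceCapabilities: return self._caps
      async def connect(self) -> None: self._connected = True
      async def is_connected(self) -> bool: return self._connected
      async def disconnect(self) -> None: self._connected = False
      async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: return None
      async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]: return None
      async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: return (None, False)
      async def collect_topology(self, visited: set, max_depth: int) -> Topology: return Topology()
      async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: pass
      async def send_opaque_status(self, request_id: str, status: str) -> None: pass
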
+ 7 - 6
exo/networking/server.py

@@ -1,10 +1,11 @@
 from abc import ABC, abstractmethod
 
+
 class Server(ABC):
-    @abstractmethod
-    async def start(self) -> None:
-        pass
+  @abstractmethod
+  async def start(self) -> None:
+    pass
 
 
-    @abstractmethod
-    async def stop(self) -> None:
-        pass
+  @abstractmethod
+  async def stop(self) -> None:
+    pass

+ 40 - 39
exo/orchestration/node.py

@@ -1,46 +1,47 @@
-from typing import Optional, Tuple, List, Callable
+from typing import Optional, Tuple, List
 import numpy as np
 from abc import ABC, abstractmethod
 from exo.helpers import AsyncCallbackSystem
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 
+
 class Node(ABC):
-    @abstractmethod
-    async def start(self, wait_for_peers: int = 0) -> None:
-        pass
-
-    @abstractmethod
-    async def stop(self) -> None:
-        pass
-
-    @abstractmethod
-    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        pass
-
-    @abstractmethod
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        pass
-
-    @abstractmethod
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        pass
-
-    @abstractmethod
-    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
-        pass
-
-    @property
-    @abstractmethod
-    def current_topology(self) -> Topology:
-        pass
-
-    @property
-    @abstractmethod
-    def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
-        pass
-
-    @property
-    @abstractmethod
-    def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
-        pass
+  @abstractmethod
+  async def start(self, wait_for_peers: int = 0) -> None:
+    pass
+
+  @abstractmethod
+  async def stop(self) -> None:
+    pass
+
+  @abstractmethod
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    pass
+
+  @abstractmethod
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    pass
+
+  @abstractmethod
+  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
+    pass
+
+  @property
+  @abstractmethod
+  def current_topology(self) -> Topology:
+    pass
+
+  @property
+  @abstractmethod
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+    pass
+
+  @property
+  @abstractmethod
+  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+    pass

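The two AsyncCallbackSystem properties are how callers observe generation progress without polling. A hedged sketch of subscribing to on_token follows; the register(...).on_next(...) chaining is inferred from its use in StandardNode.__init__ below, and the three callback arguments mirror on_token.trigger_all(request_id, tokens, is_finished).

    from typing import List


    def subscribe_to_tokens(node, request_id: str) -> None:
      # callback signature matches trigger_all(request_id, tokens, is_finished)
      def handle(req_id: str, tokens: List[int], is_finished: bool) -> None:
        if req_id == request_id:
          print(f"{len(tokens)} tokens buffered, finished={is_finished}")

      node.on_token.register("token_printer").on_next(handle)
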
+ 369 - 266
exo/orchestration/standard_node.py

@@ -14,271 +14,374 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 
+
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 256, chatgpt_api_endpoint: Optional[str] = None, web_chat_url: Optional[str] = None, disable_tui: Optional[bool] = False):
-        self.id = id
-        self.inference_engine = inference_engine
-        self.server = server
-        self.discovery = discovery
-        self.partitioning_strategy = partitioning_strategy
-        self.peers: List[PeerHandle] = {}
-        self.topology: Topology = Topology()
-        self.device_capabilities = device_capabilities()
-        self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
-        self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
-        self.max_generate_tokens = max_generate_tokens
-        self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
-        self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
-        self._on_opaque_status.register("node_status").on_next(self.on_node_status)
-
-    def on_node_status(self, request_id, opaque_status):
-        try:
-            status_data = json.loads(opaque_status)
-            if status_data.get("type", "") == "node_status":
-                if status_data.get("status", "").startswith("start_"):
-                    self.current_topology.active_node_id = status_data.get("node_id")
-                elif status_data.get("status", "").startswith("end_"):
-                    if status_data.get("node_id") == self.current_topology.active_node_id:
-                        self.current_topology.active_node_id = None
-            if self.topology_viz: self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
-        except json.JSONDecodeError:
-            pass
-
-    async def start(self, wait_for_peers: int = 0) -> None:
-        await self.server.start()
-        await self.discovery.start()
-        await self.update_peers(wait_for_peers)
+  def __init__(
+    self,
+    _id: str,
+    server: Server,
+    inference_engine: InferenceEngine,
+    discovery: Discovery,
+    partitioning_strategy: PartitioningStrategy = None,
+    max_generate_tokens: int = 256,
+    chatgpt_api_endpoint: Optional[str] = None,
+    web_chat_url: Optional[str] = None,
+    disable_tui: Optional[bool] = False,
+  ):
+    self.id = _id
+    self.inference_engine = inference_engine
+    self.server = server
+    self.discovery = discovery
+    self.partitioning_strategy = partitioning_strategy
+    self.peers: List[PeerHandle] = {}
+    self.topology: Topology = Topology()
+    self.device_capabilities = device_capabilities()
+    self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.topology_viz = TopologyViz(chatgpt_api_endpoint=chatgpt_api_endpoint, web_chat_url=web_chat_url) if not disable_tui else None
+    self.max_generate_tokens = max_generate_tokens
+    self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+    self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
+    self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+
+  def on_node_status(self, request_id, opaque_status):
+    try:
+      status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "node_status":
+        if status_data.get("status", "").startswith("start_"):
+          self.current_topology.active_node_id = status_data.get("node_id")
+        elif status_data.get("status", "").startswith("end_"):
+          if status_data.get("node_id") == self.current_topology.active_node_id:
+            self.current_topology.active_node_id = None
+      if self.topology_viz:
+        self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+    except json.JSONDecodeError:
+      pass
+
+  async def start(self, wait_for_peers: int = 0) -> None:
+    await self.server.start()
+    await self.discovery.start()
+    await self.update_peers(wait_for_peers)
+    await self.collect_topology()
+    if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+    asyncio.create_task(self.periodic_topology_collection(5))
+
+  async def stop(self) -> None:
+    await self.discovery.stop()
+    await self.server.stop()
+
+  async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "start_process_prompt",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "prompt": prompt,
+            "inference_state": inference_state,
+            "request_id": request_id,
+          }
+        ),
+      )
+    )
+    start_time = time.perf_counter_ns()
+    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    end_time = time.perf_counter_ns()
+    elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "end_process_prompt",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "prompt": prompt,
+            "inference_state": inference_state,
+            "request_id": request_id,
+            "elapsed_time_ns": elapsed_time_ns,
+            "result_size": resp.size if resp is not None else 0,
+          }
+        ),
+      )
+    )
+    return resp
+
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    shard = self.get_current_shard(base_shard)
+
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if shard.start_layer != 0:
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      await self.forward_to_next_shard(shard, prompt, request_id)
+      return
+
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
+
+    if result.size == 1:
+      self.buffered_token_output[request_id][0].append(result.item())
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+
+    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+
+    if not is_finished:
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+
+  async def process_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
+    shard = self.get_current_shard(base_shard)
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "start_process_tensor",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "tensor_size": tensor.size,
+            "tensor_shape": tensor.shape,
+            "request_id": request_id,
+            "inference_state": inference_state,
+          }
+        ),
+      )
+    )
+    start_time = time.perf_counter_ns()
+    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    end_time = time.perf_counter_ns()
+    elapsed_time_ns = end_time - start_time
+    asyncio.create_task(
+      self.broadcast_opaque_status(
+        request_id,
+        json.dumps(
+          {
+            "type": "node_status",
+            "node_id": self.id,
+            "status": "end_process_tensor",
+            "base_shard": base_shard.to_dict(),
+            "shard": shard.to_dict(),
+            "request_id": request_id,
+            "elapsed_time_ns": elapsed_time_ns,
+            "result_size": resp.size if resp is not None else 0,
+          }
+        ),
+      )
+    )
+    return resp
+
+  async def _process_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    inference_state: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
+    if request_id is None:
+      request_id = str(uuid.uuid4())
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    shard = self.get_current_shard(base_shard)
+
+    try:
+      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
+      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
+      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      if is_finished:
+        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
+
+      if result.size == 1:  # we got a new token out
+        self.buffered_token_output[request_id][0].append(result.item())
+        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+
+      if not is_finished:
+        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+
+      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    except Exception as e:
+      print(f"Error processing tensor for shard {shard}: {e}")
+      import traceback
+
+      traceback.print_exc()
+      return None
+
+  async def forward_to_next_shard(
+    self,
+    base_shard: Shard,
+    tensor_or_prompt: Union[np.ndarray, str],
+    request_id: str,
+    inference_state: Optional[str] = None,
+  ) -> None:
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return
+    shard = self.get_current_shard(base_shard)
+
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
+    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
+    if current_partition_index is not None:
+      next_partition_index = (current_partition_index + 1) % len(partitions)
+      next_partition: Partition = partitions[next_partition_index]
+      next_shard = shards[next_partition_index]
+      if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
+
+      if next_partition.node_id == self.id:
+        if isinstance(tensor_or_prompt, np.ndarray):
+          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        else:
+          await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        return
+
+      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {next_partition} not found")
+
+      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
+
+      if isinstance(tensor_or_prompt, np.ndarray):
+        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+      else:
+        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+
+  def get_current_shard(self, base_shard: Shard) -> Shard:
+    partitions = self.partitioning_strategy.partition(self.topology)
+    shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+    return shards[current_partition_index]
+
+  async def update_peers(self, wait_for_peers: int = 0) -> None:
+    self.peers = await self.discovery.discover_peers(wait_for_peers)
+    if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
+    if DEBUG >= 2: print("Connecting to new peers...")
+    for peer in self.peers:
+      is_connected = await peer.is_connected()
+      if DEBUG >= 2 and is_connected:
+        print(f"Already connected to {peer.id()}: {is_connected}")
+      if not is_connected:
+        await peer.connect()
+        if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
+
+  async def periodic_topology_collection(self, interval: int):
+    while True:
+      await asyncio.sleep(interval)
+      try:
+        await self.update_peers()
         await self.collect_topology()
-        if DEBUG >= 2: print(f"Collected topology: {self.topology}")
-        asyncio.create_task(self.periodic_topology_collection(5))
-
-    async def stop(self) -> None:
-        await self.discovery.stop()
-        await self.server.stop()
-
-    async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        shard = self.get_current_shard(base_shard)
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id})))
-        start_time = time.perf_counter_ns()
-        resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
-        end_time = time.perf_counter_ns()
-        elapsed_time_ns = end_time - start_time
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_prompt", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "prompt": prompt, "inference_state": inference_state, "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
-        return resp
-
-    async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        if request_id is None:
-            request_id = str(uuid.uuid4())
-        if request_id not in self.buffered_token_output:
-            self.buffered_token_output[request_id] = ([], False)
-        shard = self.get_current_shard(base_shard)
-
-        if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-        if shard.start_layer != 0:
-            if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-            await self.forward_to_next_shard(shard, prompt, request_id)
-            return
-
-        result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
-        is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-        if is_finished:
-            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-        asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity
-
-        if result.size == 1:
-            self.buffered_token_output[request_id][0].append(result.item())
-            self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-        if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-        if not is_finished:
-            asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-        return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
-
-    async def process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        shard = self.get_current_shard(base_shard)
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "start_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "tensor_size": tensor.size, "tensor_shape": tensor.shape, "request_id": request_id, "inference_state": inference_state})))
-        start_time = time.perf_counter_ns()
-        resp = await self._process_tensor(shard, tensor, request_id, inference_state)
-        end_time = time.perf_counter_ns()
-        elapsed_time_ns = end_time - start_time
-        asyncio.create_task(self.broadcast_opaque_status(request_id, json.dumps({"type": "node_status", "node_id": self.id, "status": "end_process_tensor", "base_shard": base_shard.to_dict(), "shard": shard.to_dict(), "request_id": request_id, "elapsed_time_ns": elapsed_time_ns, "result_size": resp.size if resp is not None else 0})))
-        return resp
-
-    async def _process_tensor(self, base_shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
-        if request_id is None:
-            request_id = str(uuid.uuid4())
-        if request_id not in self.buffered_token_output:
-            self.buffered_token_output[request_id] = ([], False)
-        shard = self.get_current_shard(base_shard)
-
-        try:
-            if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-            result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-            is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-            if is_finished:
-                self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-            asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) # TODO: this is n^2 communication complexity
-
-            if result.size == 1:  # we got a new token out
-                self.buffered_token_output[request_id][0].append(result.item())
-                self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-            if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-            if not is_finished:
-                asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-            return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
-        except Exception as e:
-            print(f"Error processing tensor for shard {shard}: {e}")
-            import traceback
-            traceback.print_exc()
-            return None
-
-    async def forward_to_next_shard(self, base_shard: Shard, tensor_or_prompt: Union[np.ndarray, str], request_id: str, inference_state: Optional[str] = None) -> None:
-        if not self.partitioning_strategy:
-            if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-            return
-        shard = self.get_current_shard(base_shard)
-
-        partitions = self.partitioning_strategy.partition(self.topology)
-        shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
-        current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-        if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-        if current_partition_index is not None:
-            next_partition_index = (current_partition_index + 1) % len(partitions)
-            next_partition: Partition = partitions[next_partition_index]
-            next_shard = shards[next_partition_index]
-            if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-            if next_partition.node_id == self.id:
-                if isinstance(tensor_or_prompt, np.ndarray):
-                    await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-                else:
-                    await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-                return
-
-            target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-            if not target_peer:
-                raise ValueError(f"Peer for {next_partition} not found")
-
-            if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-            if isinstance(tensor_or_prompt, np.ndarray):
-                await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-            else:
-                await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-
-    def get_current_shard(self, base_shard: Shard) -> Shard:
-        partitions = self.partitioning_strategy.partition(self.topology)
-        shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
-        current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-        if current_partition_index is None:
-            raise ValueError(f"No current partition found for node: {self.id}")
-        return shards[current_partition_index]
-
-    async def update_peers(self, wait_for_peers: int = 0) -> None:
-        self.peers = await self.discovery.discover_peers(wait_for_peers)
-        if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
-        if DEBUG >= 2: print("Connecting to new peers...")
-        for peer in self.peers:
-            is_connected = await peer.is_connected()
-            if DEBUG >= 2 and is_connected: print(f"Already connected to {peer.id()}: {is_connected}")
-            if not is_connected:
-                await peer.connect()
-                if DEBUG >= 0: print(f"Connected to peer {peer.device_capabilities()} ({peer.id()=})")
-
-    async def periodic_topology_collection(self, interval: int):
-        while True:
-            await asyncio.sleep(interval)
-            try:
-                await self.update_peers()
-                await self.collect_topology()
-            except Exception as e:
-                print(f"Error collecting topology: {e}")
-
-            if DEBUG >= 2: print("Topology collection task executed.")
-            if DEBUG >= 2: print(f"Current topology: {self.topology}")
-
-    async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
-        if request_id not in self.buffered_token_output:
-            return None, False
-        return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
-
-    async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
-        next_topology = Topology()
-        next_topology.update_node(self.id, self.device_capabilities)
-
-        if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
-
-        prev_visited = visited.copy()
-        visited.update(p.id() for p in self.peers)
-
-        for peer in self.peers:
-            next_topology.update_node(peer.id(), peer.device_capabilities())
-            next_topology.add_edge(self.id, peer.id())
-
-            if peer.id() in prev_visited:
-                if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
-                continue
-
-            if max_depth <= 0:
-                if DEBUG >= 2: print(f"Max depth reached. Skipping...")
-                continue
-
-            try:
-                other_topology = await peer.collect_topology(visited, max_depth = max_depth - 1)
-                if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
-                self.topology.merge(other_topology)
-            except Exception as e:
-                print(f"Error collecting topology from {peer.id()}: {e}")
-
-        next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
-        self.topology = next_topology
-        if self.topology_viz: self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
-        return next_topology
-
-    @property
-    def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
-        return self._on_token
-
-    @property
-    def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
-        return self._on_opaque_status
-
-    def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
-        if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
-        self.on_token.trigger_all(request_id, tokens, is_finished)
-
-    async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
-        async def send_result_to_peer(peer):
-            try:
-                await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
-            except asyncio.TimeoutError:
-                print(f"Timeout broadcasting result to {peer.id()}")
-            except Exception as e:
-                print(f"Error broadcasting result to {peer.id()}: {e}")
-                import traceback
-                traceback.print_exc()
-
-        await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
-
-    async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
-        async def send_status_to_peer(peer):
-            try:
-                await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
-            except asyncio.TimeoutError:
-                print(f"Timeout sending opaque status to {peer.id()}")
-            except Exception as e:
-                print(f"Error sending opaque status to {peer.id()}: {e}")
-                import traceback
-                traceback.print_exc()
-
-        await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
-        # in the case of opaque status, we also want to receive our own opaque statuses
-        self.on_opaque_status.trigger_all(request_id, status)
-
-    @property
-    def current_topology(self) -> Topology:
-        return self.topology
+      except Exception as e:
+        print(f"Error collecting topology: {e}")
+
+      if DEBUG >= 2: print("Topology collection task executed.")
+      if DEBUG >= 2: print(f"Current topology: {self.topology}")
+
+  async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+    if request_id not in self.buffered_token_output:
+      return None, False
+    return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+
+  async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
+    next_topology = Topology()
+    next_topology.update_node(self.id, self.device_capabilities)
+
+    if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
+
+    prev_visited = visited.copy()
+    visited.update(p.id() for p in self.peers)
+
+    for peer in self.peers:
+      next_topology.update_node(peer.id(), peer.device_capabilities())
+      next_topology.add_edge(self.id, peer.id())
+
+      if peer.id() in prev_visited:
+        if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
+        continue
+
+      if max_depth <= 0:
+        if DEBUG >= 2: print("Max depth reached. Skipping...")
+        continue
+
+      try:
+        other_topology = await peer.collect_topology(visited, max_depth=max_depth - 1)
+        if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
+        self.topology.merge(other_topology)
+      except Exception as e:
+        print(f"Error collecting topology from {peer.id()}: {e}")
+
+    next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
+    self.topology = next_topology
+    if self.topology_viz:
+      self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
+    return next_topology
+
+  @property
+  def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+    return self._on_token
+
+  @property
+  def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+    return self._on_opaque_status
+
+  def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
+    self.on_token.trigger_all(request_id, tokens, is_finished)
+
+  async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+    async def send_result_to_peer(peer):
+      try:
+        await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
+      except asyncio.TimeoutError:
+        print(f"Timeout broadcasting result to {peer.id()}")
+      except Exception as e:
+        print(f"Error broadcasting result to {peer.id()}: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+    await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+
+  async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+    async def send_status_to_peer(peer):
+      try:
+        await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
+      except asyncio.TimeoutError:
+        print(f"Timeout sending opaque status to {peer.id()}")
+      except Exception as e:
+        print(f"Error sending opaque status to {peer.id()}: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+    await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
+    # in the case of opaque status, we also want to receive our own opaque statuses
+    self.on_opaque_status.trigger_all(request_id, status)
+
+  @property
+  def current_topology(self) -> Topology:
+    return self.topology

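Putting the pieces of this diff together, a hedged sketch of wiring up a StandardNode: discovery, server, inference engine and partitioning strategy. The ports, the strategy class name (assumed from exo/topology/ring_memory_weighted_partitioning_strategy.py) and the trick of attaching the server after construction are assumptions; main.py in this repo is the authoritative wiring.

    import asyncio
    import uuid

    from exo.networking.grpc.grpc_discovery import GRPCDiscovery
    from exo.networking.grpc.grpc_server import GRPCServer
    from exo.orchestration.standard_node import StandardNode
    from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy


    async def run(inference_engine):  # engine construction (mlx/tinygrad backend) omitted on purpose
      node_id = str(uuid.uuid4())
      discovery = GRPCDiscovery(node_id, 50051, 5678, 5679)  # listen/broadcast ports are illustrative
      node = StandardNode(
        node_id,
        None,  # server attached below to break the node <-> server construction cycle
        inference_engine,
        discovery,
        partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
        disable_tui=True,
      )
      node.server = GRPCServer(node, "0.0.0.0", 50051)
      await node.start(wait_for_peers=0)
      try:
        await asyncio.Event().wait()  # keep serving until cancelled
      finally:
        await node.stop()
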
+ 49 - 48
exo/orchestration/test_node.py

@@ -5,52 +5,53 @@ import numpy as np
 from .standard_node import StandardNode
 from exo.networking.peer_handle import PeerHandle
 
+
 class TestNode(unittest.IsolatedAsyncioTestCase):
-    def setUp(self):
-        self.mock_inference_engine = AsyncMock()
-        self.mock_server = AsyncMock()
-        self.mock_server.start = AsyncMock()
-        self.mock_server.stop = AsyncMock()
-        self.mock_discovery = AsyncMock()
-        self.mock_discovery.start = AsyncMock()
-        self.mock_discovery.stop = AsyncMock()
-        mock_peer1 = Mock(spec=PeerHandle)
-        mock_peer1.id.return_value = "peer1"
-        mock_peer2 = Mock(spec=PeerHandle)
-        mock_peer2.id.return_value = "peer2"
-        self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
-
-        self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
-
-    async def asyncSetUp(self):
-        await self.node.start()
-
-    async def asyncTearDown(self):
-        await self.node.stop()
-
-    async def test_node_initialization(self):
-        self.assertEqual(self.node.node_id, "test_node")
-        self.assertEqual(self.node.host, "localhost")
-        self.assertEqual(self.node.port, 50051)
-
-    async def test_node_start(self):
-        self.mock_server.start.assert_called_once_with("localhost", 50051)
-
-    async def test_node_stop(self):
-        await self.node.stop()
-        self.mock_server.stop.assert_called_once()
-
-    async def test_discover_and_connect_to_peers(self):
-        await self.node.discover_and_connect_to_peers()
-        self.assertEqual(len(self.node.peers), 2)
-        self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
-        self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
-
-    async def test_process_tensor_calls_inference_engine(self):
-        mock_peer = Mock()
-        self.node.peers = [mock_peer]
-
-        input_tensor = np.array([69, 1, 2])
-        await self.node.process_tensor(input_tensor, None)
-
-        self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
+  def setUp(self):
+    self.mock_inference_engine = AsyncMock()
+    self.mock_server = AsyncMock()
+    self.mock_server.start = AsyncMock()
+    self.mock_server.stop = AsyncMock()
+    self.mock_discovery = AsyncMock()
+    self.mock_discovery.start = AsyncMock()
+    self.mock_discovery.stop = AsyncMock()
+    mock_peer1 = Mock(spec=PeerHandle)
+    mock_peer1.id.return_value = "peer1"
+    mock_peer2 = Mock(spec=PeerHandle)
+    mock_peer2.id.return_value = "peer2"
+    self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
+
+    self.node = StandardNode("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
+
+  async def asyncSetUp(self):
+    await self.node.start()
+
+  async def asyncTearDown(self):
+    await self.node.stop()
+
+  async def test_node_initialization(self):
+    self.assertEqual(self.node.node_id, "test_node")
+    self.assertEqual(self.node.host, "localhost")
+    self.assertEqual(self.node.port, 50051)
+
+  async def test_node_start(self):
+    self.mock_server.start.assert_called_once_with("localhost", 50051)
+
+  async def test_node_stop(self):
+    await self.node.stop()
+    self.mock_server.stop.assert_called_once()
+
+  async def test_discover_and_connect_to_peers(self):
+    await self.node.discover_and_connect_to_peers()
+    self.assertEqual(len(self.node.peers), 2)
+    self.assertIn("peer1", map(lambda p: p.id(), self.node.peers))
+    self.assertIn("peer2", map(lambda p: p.id(), self.node.peers))
+
+  async def test_process_tensor_calls_inference_engine(self):
+    mock_peer = Mock()
+    self.node.peers = [mock_peer]
+
+    input_tensor = np.array([69, 1, 2])
+    await self.node.process_tensor(input_tensor, None)
+
+    self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)

+ 19 - 18
exo/stats/metrics.py

@@ -1,28 +1,29 @@
 from exo.orchestration import Node
 from prometheus_client import start_http_server, Counter, Histogram
 import json
-from typing import List
 
 # Create metrics to track time spent and requests made.
-PROCESS_PROMPT_COUNTER = Counter('process_prompt_total', 'Total number of prompts processed', ['node_id'])
-PROCESS_TENSOR_COUNTER = Counter('process_tensor_total', 'Total number of tensors processed', ['node_id'])
-PROCESS_TENSOR_TIME = Histogram('process_tensor_seconds', 'Time spent processing tensor', ['node_id'])
+PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
+PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
+PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
+
 
 
 def start_metrics_server(node: Node, port: int):
-    start_http_server(port)
+  start_http_server(port)
 
 
-    def _on_opaque_status(request_id, opaque_status: str):
-        status_data = json.loads(opaque_status)
-        type = status_data.get("type", "")
-        node_id = status_data.get("node_id", "")
-        if type != "node_status": return
-        status = status_data.get("status", "")
+  def _on_opaque_status(request_id, opaque_status: str):
+    status_data = json.loads(opaque_status)
+    _type = status_data.get("type", "")
+    node_id = status_data.get("node_id", "")
+    if _type != "node_status":
+      return
+    status = status_data.get("status", "")
 
 
-        if status == "end_process_prompt":
-            PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
-        elif status == "end_process_tensor":
-            elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
-            PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
-            PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9)  # Convert ns to seconds
+    if status == "end_process_prompt":
+      PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
+    elif status == "end_process_tensor":
+      elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
+      PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
+      PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns / 1e9)  # Convert ns to seconds
 
 
-    node.on_opaque_status.register("stats").on_next(_on_opaque_status)
+  node.on_opaque_status.register("stats").on_next(_on_opaque_status)
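
For reference, a hedged sketch of the status payloads that _on_opaque_status reacts to; the field names come from the parser above, while the node id and timing values are purely illustrative.

    import json

    prompt_done = json.dumps({"type": "node_status", "node_id": "node1", "status": "end_process_prompt"})
    tensor_done = json.dumps({
      "type": "node_status",
      "node_id": "node1",
      "status": "end_process_tensor",
      "elapsed_time_ns": 2_500_000_000,  # recorded as 2.5 seconds in the histogram
    })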

+ 33 - 30
exo/test_callbacks.py

@@ -2,46 +2,49 @@ import asyncio
 from typing import Any, Callable
 from exo.helpers import AsyncCallbackSystem, AsyncCallback
 
+
 # Usage example
 async def main() -> None:
-    callback_system = AsyncCallbackSystem[str, Any]()
+  callback_system = AsyncCallbackSystem[str, Any]()
+
+  # Register callbacks
+  callback1 = callback_system.register("callback1")
+  callback2 = callback_system.register("callback2")
+
+  def on_next_callback(name: str) -> Callable[..., None]:
+    def callback(*args: Any) -> None:
+      print(f"{name} received values: {args}")
 
 
-    # Register callbacks
-    callback1 = callback_system.register("callback1")
-    callback2 = callback_system.register("callback2")
+    return callback
 
 
-    def on_next_callback(name: str) -> Callable[..., None]:
-        def callback(*args: Any) -> None:
-            print(f"{name} received values: {args}")
-        return callback
+  callback1.on_next(on_next_callback("Callback1"))
+  callback2.on_next(on_next_callback("Callback2"))
 
 
-    callback1.on_next(on_next_callback("Callback1"))
-    callback2.on_next(on_next_callback("Callback2"))
+  async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
+    try:
+      result = await callback.wait(condition, timeout=2)
+      print(f"{name} wait completed with result: {result}")
+    except asyncio.TimeoutError:
+      print(f"{name} wait timed out")
 
 
-    async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None:
-        try:
-            result = await callback.wait(condition, timeout=2)
-            print(f"{name} wait completed with result: {result}")
-        except asyncio.TimeoutError:
-            print(f"{name} wait timed out")
+  # Trigger all callbacks at once
+  callback_system.trigger_all("Hello", 42, True)
 
 
-    # Trigger all callbacks at once
-    callback_system.trigger_all("Hello", 42, True)
+  # Wait for all callbacks with different conditions
+  await asyncio.gather(
+    wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
+    wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True),
+  )
 
 
-    # Wait for all callbacks with different conditions
-    await asyncio.gather(
-        wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0),
-        wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True)
-    )
+  # Trigger individual callback
+  callback_system.trigger("callback2", "World", -10, False)
 
 
-    # Trigger individual callback
-    callback_system.trigger("callback2", "World", -10, False)
+  # Demonstrate timeout
+  new_callback = callback_system.register("new_callback")
+  new_callback.on_next(on_next_callback("NewCallback"))
+  await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
 
 
-    # Demonstrate timeout
-    new_callback = callback_system.register("new_callback")
-    new_callback.on_next(on_next_callback("NewCallback"))
-    await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100)
+  callback_system.trigger("callback2", "World", 200, False)
 
 
-    callback_system.trigger("callback2", "World", 200, False)
 
 
 asyncio.run(main())

+ 154 - 132
exo/topology/device_capabilities.py

@@ -5,153 +5,175 @@ import psutil
 
 
 TFLOPS = 1.00
 
+
 @dataclass
 class DeviceFlops:
-    # units of TFLOPS
-    fp32: float
-    fp16: float
-    int8: float
+  # units of TFLOPS
+  fp32: float
+  fp16: float
+  int8: float
+
+  def __str__(self):
+    return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
 
 
-    def __str__(self):
-        return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
+  def to_dict(self):
+    return asdict(self)
 
 
-    def to_dict(self):
-        return asdict(self)
 
 
 @dataclass
 class DeviceCapabilities:
-    model: str
-    chip: str
-    memory: int
-    flops: DeviceFlops
-
-    def __str__(self):
-        return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
-
-    def __post_init__(self):
-        if isinstance(self.flops, dict):
-            self.flops = DeviceFlops(**self.flops)
-
-    def to_dict(self):
-        return {
-            'model': self.model,
-            'chip': self.chip,
-            'memory': self.memory,
-            'flops': self.flops.to_dict()
-        }
+  model: str
+  chip: str
+  memory: int
+  flops: DeviceFlops
+
+  def __str__(self):
+    return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
+
+  def __post_init__(self):
+    if isinstance(self.flops, dict):
+      self.flops = DeviceFlops(**self.flops)
+
+  def to_dict(self):
+    return {"model": self.model, "chip": self.chip, "memory": self.memory, "flops": self.flops.to_dict()}
+
 
 
 UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
 
 CHIP_FLOPS = {
-    # Source: https://www.cpu-monkey.com
-    # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
-    ### M chips
-    "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS),
-    "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS),
-    "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS),
-    "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS),
-    "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-    "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS),
-    "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
-    "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
-    "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-    "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
-    "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-    "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-    ### A chips
-    "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
-    "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
-    "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS),
-    "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, int8=7.16*TFLOPS),
-    "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS),
-    ### NVIDIA GPUs
-    #RTX 40 series
-    "Nvidia GeForce RTX 4090": DeviceFlops(fp32=82.58*TFLOPS, fp16=165.16*TFLOPS, int8=330.32*TFLOPS),
-    "Nvidia GeForce RTX 4080": DeviceFlops(fp32=48.74*TFLOPS, fp16=97.48*TFLOPS, int8=194.96*TFLOPS),
-    "Nvidia GeForce RTX 4080 Super": DeviceFlops(fp32=52.0*TFLOPS, fp16=104.0*TFLOPS, int8=208.0*TFLOPS),
-    "Nvidia GeForce RTX 4070 Ti Super": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
-    "Nvidia GeForce RTX 4070 Ti": DeviceFlops(fp32=39.43*TFLOPS, fp16=78.86*TFLOPS, int8=157.72*TFLOPS),
-    "Nvidia GeForce RTX 4070 Super": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
-    "Nvidia GeForce RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
-    "Nvidia GeForce RTX 4060 Ti 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
-    #RTX 30 series
-    "Nvidia GeForce RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
-    "Nvidia GeForce RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
-    "Nvidia GeForce RTX 3060 Ti": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
-    "Nvidia GeForce RTX 3070": DeviceFlops(fp32=20.3*TFLOPS, fp16=40.6*TFLOPS, int8=81.2*TFLOPS),
-    "Nvidia GeForce RTX 3070 Ti": DeviceFlops(fp32=21.8*TFLOPS, fp16=43.6*TFLOPS, int8=87.2*TFLOPS),
-    "Nvidia GeForce RTX 3080 (10 GB)": DeviceFlops(fp32=29.8*TFLOPS, fp16=59.6*TFLOPS, int8=119.2*TFLOPS),
-    "Nvidia GeForce RTX 3080 (12 GB)": DeviceFlops(fp32=30.6*TFLOPS, fp16=61.2*TFLOPS, int8=122.4*TFLOPS),
-    "Nvidia GeForce RTX 3080 Ti": DeviceFlops(fp32=34.1*TFLOPS, fp16=68.2*TFLOPS, int8=136.4*TFLOPS),
-    "Nvidia GeForce RTX 3090": DeviceFlops(fp32=35.6*TFLOPS, fp16=71.2*TFLOPS, int8=142.4*TFLOPS),
-    "Nvidia GeForce RTX 3090 Ti": DeviceFlops(fp32=40.0*TFLOPS, fp16=80.0*TFLOPS, int8=160.0*TFLOPS),
-        # ... add more devices if needed ...
-    ### AMD GPUs
-    # RX 6000 series
-    "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04*TFLOPS, fp16=46.08*TFLOPS, int8=92.16*TFLOPS),
-    "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74*TFLOPS, fp16=41.48*TFLOPS, int8=82.96*TFLOPS),
-    "AMD Radeon RX 6800": DeviceFlops(fp32=16.17*TFLOPS, fp16=32.34*TFLOPS, int8=64.68*TFLOPS),
-    "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21*TFLOPS, fp16=26.42*TFLOPS, int8=52.84*TFLOPS),
-    "AMD Radeon RX 6700": DeviceFlops(fp32=11.4*TFLOPS, fp16=22.8*TFLOPS, int8=45.6*TFLOPS),
-    "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6*TFLOPS, fp16=21.2*TFLOPS, int8=42.4*TFLOPS),
-    "AMD Radeon RX 6600": DeviceFlops(fp32=8.93*TFLOPS, fp16=17.86*TFLOPS, int8=35.72*TFLOPS),
-    "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77*TFLOPS, fp16=11.54*TFLOPS, int8=23.08*TFLOPS),
-    "AMD Radeon RX 6400": DeviceFlops(fp32=3.57*TFLOPS, fp16=7.14*TFLOPS, int8=14.28*TFLOPS),
-    # RX 7000 series
-    "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4*TFLOPS, fp16=122.8*TFLOPS, int8=245.6*TFLOPS),
-    "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4*TFLOPS, fp16=106.8*TFLOPS, int8=213.6*TFLOPS),
-    "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6*TFLOPS, fp16=85.2*TFLOPS, int8=170.4*TFLOPS),
-    "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2*TFLOPS, fp16=68.4*TFLOPS, int8=136.8*TFLOPS),
-    "AMD Radeon RX 7600": DeviceFlops(fp32=21.5*TFLOPS, fp16=43.0*TFLOPS, int8=86.0*TFLOPS),
-    "AMD Radeon RX 7500": DeviceFlops(fp32=16.2*TFLOPS, fp16=32.4*TFLOPS, int8=64.8*TFLOPS),
-    # ... add more devices if needed ...
-    ### Qualcomm embedded chips: TODO
+  # Source: https://www.cpu-monkey.com
+  # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
+  ### M chips
+  "Apple M1": DeviceFlops(fp32=2.29 * TFLOPS, fp16=4.58 * TFLOPS, int8=9.16 * TFLOPS),
+  "Apple M1 Pro": DeviceFlops(fp32=5.30 * TFLOPS, fp16=10.60 * TFLOPS, int8=21.20 * TFLOPS),
+  "Apple M1 Max": DeviceFlops(fp32=10.60 * TFLOPS, fp16=21.20 * TFLOPS, int8=42.40 * TFLOPS),
+  "Apple M1 Ultra": DeviceFlops(fp32=21.20 * TFLOPS, fp16=42.40 * TFLOPS, int8=84.80 * TFLOPS),
+  "Apple M2": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  "Apple M2 Pro": DeviceFlops(fp32=5.68 * TFLOPS, fp16=11.36 * TFLOPS, int8=22.72 * TFLOPS),
+  "Apple M2 Max": DeviceFlops(fp32=13.49 * TFLOPS, fp16=26.98 * TFLOPS, int8=53.96 * TFLOPS),
+  "Apple M2 Ultra": DeviceFlops(fp32=26.98 * TFLOPS, fp16=53.96 * TFLOPS, int8=107.92 * TFLOPS),
+  "Apple M3": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS),
+  "Apple M3 Pro": DeviceFlops(fp32=4.97 * TFLOPS, fp16=9.94 * TFLOPS, int8=19.88 * TFLOPS),
+  "Apple M4": DeviceFlops(fp32=3.55 * TFLOPS, fp16=7.10 * TFLOPS, int8=14.20 * TFLOPS),
+  ### A chips
+  "Apple A13 Bionic": DeviceFlops(fp32=0.69 * TFLOPS, fp16=1.38 * TFLOPS, int8=2.76 * TFLOPS),
+  "Apple A14 Bionic": DeviceFlops(fp32=0.75 * TFLOPS, fp16=1.50 * TFLOPS, int8=3.00 * TFLOPS),
+  "Apple A15 Bionic": DeviceFlops(fp32=1.37 * TFLOPS, fp16=2.74 * TFLOPS, int8=5.48 * TFLOPS),
+  "Apple A16 Bionic": DeviceFlops(fp32=1.79 * TFLOPS, fp16=3.58 * TFLOPS, int8=7.16 * TFLOPS),
+  "Apple A17 Pro": DeviceFlops(fp32=2.15 * TFLOPS, fp16=4.30 * TFLOPS, int8=8.60 * TFLOPS),
+  ### NVIDIA GPUs
+  # RTX 40 series
+  "Nvidia GeForce RTX 4090": DeviceFlops(fp32=82.58 * TFLOPS, fp16=165.16 * TFLOPS, int8=330.32 * TFLOPS),
+  "Nvidia GeForce RTX 4080": DeviceFlops(fp32=48.74 * TFLOPS, fp16=97.48 * TFLOPS, int8=194.96 * TFLOPS),
+  "Nvidia GeForce RTX 4080 Super": DeviceFlops(fp32=52.0 * TFLOPS, fp16=104.0 * TFLOPS, int8=208.0 * TFLOPS),
+  "Nvidia GeForce RTX 4070 Ti Super": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
+  "Nvidia GeForce RTX 4070 Ti": DeviceFlops(fp32=39.43 * TFLOPS, fp16=78.86 * TFLOPS, int8=157.72 * TFLOPS),
+  "Nvidia GeForce RTX 4070 Super": DeviceFlops(fp32=30.0 * TFLOPS, fp16=60.0 * TFLOPS, int8=120.0 * TFLOPS),
+  "Nvidia GeForce RTX 4070": DeviceFlops(fp32=29.0 * TFLOPS, fp16=58.0 * TFLOPS, int8=116.0 * TFLOPS),
+  "Nvidia GeForce RTX 4060 Ti 16GB": DeviceFlops(fp32=22.0 * TFLOPS, fp16=44.0 * TFLOPS, int8=88.0 * TFLOPS),
+  # RTX 30 series
+  "Nvidia GeForce RTX 3050": DeviceFlops(fp32=9.11 * TFLOPS, fp16=18.22 * TFLOPS, int8=36.44 * TFLOPS),
+  "Nvidia GeForce RTX 3060": DeviceFlops(fp32=13.0 * TFLOPS, fp16=26.0 * TFLOPS, int8=52.0 * TFLOPS),
+  "Nvidia GeForce RTX 3060 Ti": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
+  "Nvidia GeForce RTX 3070": DeviceFlops(fp32=20.3 * TFLOPS, fp16=40.6 * TFLOPS, int8=81.2 * TFLOPS),
+  "Nvidia GeForce RTX 3070 Ti": DeviceFlops(fp32=21.8 * TFLOPS, fp16=43.6 * TFLOPS, int8=87.2 * TFLOPS),
+  "Nvidia GeForce RTX 3080 (10 GB)": DeviceFlops(fp32=29.8 * TFLOPS, fp16=59.6 * TFLOPS, int8=119.2 * TFLOPS),
+  "Nvidia GeForce RTX 3080 (12 GB)": DeviceFlops(fp32=30.6 * TFLOPS, fp16=61.2 * TFLOPS, int8=122.4 * TFLOPS),
+  "Nvidia GeForce RTX 3080 Ti": DeviceFlops(fp32=34.1 * TFLOPS, fp16=68.2 * TFLOPS, int8=136.4 * TFLOPS),
+  "Nvidia GeForce RTX 3090": DeviceFlops(fp32=35.6 * TFLOPS, fp16=71.2 * TFLOPS, int8=142.4 * TFLOPS),
+  "Nvidia GeForce RTX 3090 Ti": DeviceFlops(fp32=40.0 * TFLOPS, fp16=80.0 * TFLOPS, int8=160.0 * TFLOPS),
+  # ... add more devices if needed ...
+  ### AMD GPUs
+  # RX 6000 series
+  "AMD Radeon RX 6900 XT": DeviceFlops(fp32=23.04 * TFLOPS, fp16=46.08 * TFLOPS, int8=92.16 * TFLOPS),
+  "AMD Radeon RX 6800 XT": DeviceFlops(fp32=20.74 * TFLOPS, fp16=41.48 * TFLOPS, int8=82.96 * TFLOPS),
+  "AMD Radeon RX 6800": DeviceFlops(fp32=16.17 * TFLOPS, fp16=32.34 * TFLOPS, int8=64.68 * TFLOPS),
+  "AMD Radeon RX 6700 XT": DeviceFlops(fp32=13.21 * TFLOPS, fp16=26.42 * TFLOPS, int8=52.84 * TFLOPS),
+  "AMD Radeon RX 6700": DeviceFlops(fp32=11.4 * TFLOPS, fp16=22.8 * TFLOPS, int8=45.6 * TFLOPS),
+  "AMD Radeon RX 6600 XT": DeviceFlops(fp32=10.6 * TFLOPS, fp16=21.2 * TFLOPS, int8=42.4 * TFLOPS),
+  "AMD Radeon RX 6600": DeviceFlops(fp32=8.93 * TFLOPS, fp16=17.86 * TFLOPS, int8=35.72 * TFLOPS),
+  "AMD Radeon RX 6500 XT": DeviceFlops(fp32=5.77 * TFLOPS, fp16=11.54 * TFLOPS, int8=23.08 * TFLOPS),
+  "AMD Radeon RX 6400": DeviceFlops(fp32=3.57 * TFLOPS, fp16=7.14 * TFLOPS, int8=14.28 * TFLOPS),
+  # RX 7000 series
+  "AMD Radeon RX 7900 XTX": DeviceFlops(fp32=61.4 * TFLOPS, fp16=122.8 * TFLOPS, int8=245.6 * TFLOPS),
+  "AMD Radeon RX 7900 XT": DeviceFlops(fp32=53.4 * TFLOPS, fp16=106.8 * TFLOPS, int8=213.6 * TFLOPS),
+  "AMD Radeon RX 7800 XT": DeviceFlops(fp32=42.6 * TFLOPS, fp16=85.2 * TFLOPS, int8=170.4 * TFLOPS),
+  "AMD Radeon RX 7700 XT": DeviceFlops(fp32=34.2 * TFLOPS, fp16=68.4 * TFLOPS, int8=136.8 * TFLOPS),
+  "AMD Radeon RX 7600": DeviceFlops(fp32=21.5 * TFLOPS, fp16=43.0 * TFLOPS, int8=86.0 * TFLOPS),
+  "AMD Radeon RX 7500": DeviceFlops(fp32=16.2 * TFLOPS, fp16=32.4 * TFLOPS, int8=64.8 * TFLOPS),
+  # ... add more devices if needed ...
+  ### Qualcomm embedded chips: TODO
 }
 
+
 def device_capabilities() -> DeviceCapabilities:
-    if psutil.MACOS:
-        return mac_device_capabilities()
-    elif psutil.LINUX:
-        return linux_device_capabilities()
-    else:
-        return DeviceCapabilities(model=f"Unknown Device", chip=f"Unknown Chip", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+  if psutil.MACOS:
+    return mac_device_capabilities()
+  elif psutil.LINUX:
+    return linux_device_capabilities()
+  else:
+    return DeviceCapabilities(
+      model="Unknown Device",
+      chip="Unknown Chip",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+
 
 
 def mac_device_capabilities() -> DeviceCapabilities:
-    # Fetch the model of the Mac using system_profiler
-    model = subprocess.check_output(['system_profiler', 'SPHardwareDataType']).decode('utf-8')
-    model_line = next((line for line in model.split('\n') if "Model Name" in line), None)
-    model_id = model_line.split(': ')[1] if model_line else "Unknown Model"
-    chip_line = next((line for line in model.split('\n') if "Chip" in line), None)
-    chip_id = chip_line.split(': ')[1] if chip_line else "Unknown Chip"
-    memory_line = next((line for line in model.split('\n') if "Memory" in line), None)
-    memory_str = memory_line.split(': ')[1] if memory_line else "Unknown Memory"
-    memory_units = memory_str.split()
-    memory_value = int(memory_units[0])
-    if memory_units[1] == "GB":
-        memory = memory_value * 1024
-    else:
-        memory = memory_value
-
-    # Assuming static values for other attributes for demonstration
-    return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
+  # Fetch the model of the Mac using system_profiler
+  model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
+  model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
+  model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
+  chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
+  chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
+  memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
+  memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
+  memory_units = memory_str.split()
+  memory_value = int(memory_units[0])
+  if memory_units[1] == "GB":
+    memory = memory_value * 1024
+  else:
+    memory = memory_value
+
+  # Assuming static values for other attributes for demonstration
+  return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
+
 
 
 def linux_device_capabilities() -> DeviceCapabilities:
-    import psutil
-    from tinygrad import Device
-
-    if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
-    if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT=="GPU":
-        import pynvml, pynvml_utils
-        pynvml.nvmlInit()
-        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-        gpu_name = pynvml.nvmlDeviceGetName(handle)
-        gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-
-        if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
-
-        return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total // 2**20, flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)))
-    elif Device.DEFAULT == "AMD":
-        # TODO AMD support
-        return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
-    else:
-        return DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0))
+  import psutil
+  from tinygrad import Device
+
+  if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
+  if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT == "GPU":
+    import pynvml
+
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+    gpu_name = pynvml.nvmlDeviceGetName(handle)
+    gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+    if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
+
+    return DeviceCapabilities(
+      model=f"Linux Box ({gpu_name})",
+      chip=gpu_name,
+      memory=gpu_memory_info.total // 2**20,
+      flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+  elif Device.DEFAULT == "AMD":
+    # TODO AMD support
+    return DeviceCapabilities(
+      model="Linux Box (AMD)",
+      chip="Unknown AMD",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
+  else:
+    return DeviceCapabilities(
+      model=f"Linux Box (Device: {Device.DEFAULT})",
+      chip=f"Unknown Chip (Device: {Device.DEFAULT})",
+      memory=psutil.virtual_memory().total // 2**20,
+      flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+    )
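
As a small illustration (assuming only that the module is importable as exo.topology.device_capabilities), chips missing from CHIP_FLOPS fall back to zeroed DeviceFlops rather than raising:

    from exo.topology.device_capabilities import CHIP_FLOPS, DeviceFlops

    flops = CHIP_FLOPS.get("Apple M3 Max", DeviceFlops(fp32=0, fp16=0, int8=0))
    print(flops)  # fp32: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS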

+ 24 - 21
exo/topology/partitioning_strategy.py

@@ -1,37 +1,40 @@
 from abc import ABC, abstractmethod
-from typing import List, Tuple
+from typing import List
 from dataclasses import dataclass
 from .topology import Topology
 from exo.inference.shard import Shard
 
+
 # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
 @dataclass
 class Partition:
-    node_id: str
-    start: float
-    end: float
+  node_id: str
+  start: float
+  end: float
+
 
 
 class PartitioningStrategy(ABC):
 class PartitioningStrategy(ABC):
-    def partition(self, topology: Topology) -> List[Partition]:
-        pass
+  @abstractmethod
+  def partition(self, topology: Topology) -> List[Partition]:
+    pass
+
 
 
 def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
 def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]:
-    for i, partition in enumerate(partitions):
-        start_layer = int(partition.start * num_layers)
-        end_layer = int(partition.end * num_layers) - 1
+  shards = []
+  for i, partition in enumerate(partitions):
+    start_layer = int(partition.start * num_layers)
+    end_layer = int(partition.end * num_layers) - 1
 
 
-        # Ensure the last partition covers up to num_layers - 1
-        if i == len(partitions) - 1:
-            end_layer = num_layers - 1
+    # Ensure the last partition covers up to num_layers - 1
+    if i == len(partitions) - 1:
+      end_layer = num_layers - 1
 
 
-        # Ensure no empty shards
-        if start_layer <= end_layer:
-            shards.append(Shard(model_id, start_layer, end_layer, num_layers))
+    # Ensure no empty shards
+    if start_layer <= end_layer:
+      shards.append(Shard(model_id, start_layer, end_layer, num_layers))
 
 
-    # Ensure full coverage
-    if shards and shards[-1].end_layer < num_layers - 1:
-        shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
+  # Ensure full coverage
+  if shards and shards[-1].end_layer < num_layers - 1:
+    shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers)
 
 
-    return shards
+  return shards
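
A worked example of the layer arithmetic above, using the same numbers as exo/topology/test_map_partitions.py later in this diff (import paths assumed from the file layout):

    from exo.inference.shard import Shard
    from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards

    partitions = [
      Partition("node1", 0.0, 0.42857),      # int(0.42857 * 32) - 1 = 12
      Partition("node2", 0.42857, 0.71428),  # layers 13..21
      Partition("node3", 0.71428, 0.99999),  # last partition is padded out to layer 31
    ]
    assert map_partitions_to_shards(partitions, 32, "model") == [
      Shard("model", 0, 12, 32),
      Shard("model", 13, 21, 32),
      Shard("model", 22, 31, 32),
    ]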

+ 12 - 12
exo/topology/ring_memory_weighted_partitioning_strategy.py

@@ -1,18 +1,18 @@
 from typing import List
 from .partitioning_strategy import PartitioningStrategy
-from exo.inference.shard import Shard
 from .topology import Topology
 from .partitioning_strategy import Partition
 
+
 class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy):
-    def partition(self, topology: Topology) -> List[Partition]:
-        nodes = list(topology.all_nodes())
-        nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
-        total_memory = sum(node[1].memory for node in nodes)
-        partitions = []
-        start = 0
-        for node in nodes:
-            end = round(start + (node[1].memory / total_memory), 5)
-            partitions.append(Partition(node[0], start, end))
-            start = end
-        return partitions
+  def partition(self, topology: Topology) -> List[Partition]:
+    nodes = list(topology.all_nodes())
+    nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True)
+    total_memory = sum(node[1].memory for node in nodes)
+    partitions = []
+    start = 0
+    for node in nodes:
+      end = round(start + (node[1].memory / total_memory), 5)
+      partitions.append(Partition(node[0], start, end))
+      start = end
+    return partitions
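
A hedged sketch of the weighting (import paths assumed from the file layout): nodes are sorted by memory in descending order and each takes a slice of the ring proportional to its share of total memory, matching the triangle test in exo/topology/test_ring_memory_weighted_partitioning_strategy.py below.

    from exo.topology.topology import Topology
    from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
    from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy

    topology = Topology()
    for node_id, memory in [("node1", 3000), ("node2", 1000), ("node3", 6000)]:
      topology.update_node(node_id, DeviceCapabilities(model=node_id, chip=node_id, memory=memory, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))

    # With 10000 MB total: node3 -> [0.0, 0.6), node1 -> [0.6, 0.9), node2 -> [0.9, 1.0)
    partitions = RingMemoryWeightedPartitioningStrategy().partition(topology)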

+ 72 - 64
exo/topology/test_device_capabilities.py

@@ -2,82 +2,90 @@ import unittest
 from unittest.mock import patch
 from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
 
+
 class TestMacDeviceCapabilities(unittest.TestCase):
-    @patch('subprocess.check_output')
-    def test_mac_device_capabilities(self, mock_check_output):
-        # Mock the subprocess output
-        mock_check_output.return_value = b"""
+  @patch("subprocess.check_output")
+  def test_mac_device_capabilities_pro(self, mock_check_output):
+    # Mock the subprocess output
+    mock_check_output.return_value = b"""
 Hardware:
 
-    Hardware Overview:
+Hardware Overview:
 
 
-        Model Name: MacBook Pro
-        Model Identifier: Mac15,9
-        Model Number: Z1CM000EFB/A
-        Chip: Apple M3 Max
-        Total Number of Cores: 16 (12 performance and 4 efficiency)
-        Memory: 128 GB
-        System Firmware Version: 10000.000.0
-        OS Loader Version: 10000.000.0
-        Serial Number (system): XXXXXXXXXX
-        Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
-        Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
-        Activation Lock Status: Enabled
-        """
+Model Name: MacBook Pro
+Model Identifier: Mac15,9
+Model Number: Z1CM000EFB/A
+Chip: Apple M3 Max
+Total Number of Cores: 16 (12 performance and 4 efficiency)
+Memory: 128 GB
+System Firmware Version: 10000.000.0
+OS Loader Version: 10000.000.0
+Serial Number (system): XXXXXXXXXX
+Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
+Activation Lock Status: Enabled
+"""
 
 
-        # Call the function
-        result = mac_device_capabilities()
+    # Call the function
+    result = mac_device_capabilities()
 
 
-        # Check the results
-        self.assertIsInstance(result, DeviceCapabilities)
-        self.assertEqual(result.model, "MacBook Pro")
-        self.assertEqual(result.chip, "Apple M3 Max")
-        self.assertEqual(result.memory, 131072)  # 16 GB in MB
-        self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS")
+    # Check the results
+    self.assertIsInstance(result, DeviceCapabilities)
+    self.assertEqual(result.model, "MacBook Pro")
+    self.assertEqual(result.chip, "Apple M3 Max")
+    self.assertEqual(result.memory, 131072)  # 128 GB in MB
+    self.assertEqual(
+      str(result),
+      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
+    )
 
 
-    @patch('subprocess.check_output')
-    def test_mac_device_capabilities(self, mock_check_output):
-        # Mock the subprocess output
-        mock_check_output.return_value = b"""
+  @patch("subprocess.check_output")
+  def test_mac_device_capabilities_air(self, mock_check_output):
+    # Mock the subprocess output
+    mock_check_output.return_value = b"""
 Hardware:
 
-    Hardware Overview:
+Hardware Overview:
+
+Model Name: MacBook Air
+Model Identifier: Mac14,2
+Model Number: MLY33B/A
+Chip: Apple M2
+Total Number of Cores: 8 (4 performance and 4 efficiency)
+Memory: 8 GB
+System Firmware Version: 10000.00.0
+OS Loader Version: 10000.00.0
+Serial Number (system): XXXXXXXXXX
+Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
+Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
+Activation Lock Status: Disabled
+"""
 
 
-      Model Name: MacBook Air
-      Model Identifier: Mac14,2
-      Model Number: MLY33B/A
-      Chip: Apple M2
-      Total Number of Cores: 8 (4 performance and 4 efficiency)
-      Memory: 8 GB
-      System Firmware Version: 10000.00.0
-      OS Loader Version: 10000.00.0
-      Serial Number (system): XXXXXXXXXX
-      Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
-      Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX
-      Activation Lock Status: Disabled
-        """
+    # Call the function
+    result = mac_device_capabilities()
 
 
-        # Call the function
-        result = mac_device_capabilities()
+    # Check the results
+    self.assertIsInstance(result, DeviceCapabilities)
+    self.assertEqual(result.model, "MacBook Air")
+    self.assertEqual(result.chip, "Apple M2")
+    self.assertEqual(result.memory, 8192)  # 8 GB in MB
 
 
-        # Check the results
-        self.assertIsInstance(result, DeviceCapabilities)
-        self.assertEqual(result.model, "MacBook Air")
-        self.assertEqual(result.chip, "Apple M2")
-        self.assertEqual(result.memory, 8192)  # 8 GB in MB
+  @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
+  def test_mac_device_capabilities_real(self):
+    # Call the function without mocking
+    result = mac_device_capabilities()
 
 
-    @unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
-    def test_mac_device_capabilities_real(self):
-        # Call the function without mocking
-        result = mac_device_capabilities()
+    # Check the results
+    self.assertIsInstance(result, DeviceCapabilities)
+    self.assertEqual(result.model, "MacBook Pro")
+    self.assertEqual(result.chip, "Apple M3 Max")
+    self.assertEqual(result.memory, 131072)  # 128 GB in MB
+    self.assertEqual(result.flops, DeviceFlops(fp32=14.20 * TFLOPS, fp16=28.40 * TFLOPS, int8=56.80 * TFLOPS))
+    self.assertEqual(
+      str(result),
+      "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
+    )
 
 
-        # Check the results
-        self.assertIsInstance(result, DeviceCapabilities)
-        self.assertEqual(result.model, "MacBook Pro")
-        self.assertEqual(result.chip, "Apple M3 Max")
-        self.assertEqual(result.memory, 131072)  # 128 GB in MB
-        self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
-        self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS")
 
 
-if __name__ == '__main__':
-    unittest.main()
+if __name__ == "__main__":
+  unittest.main()

+ 68 - 55
exo/topology/test_map_partitions.py

@@ -3,66 +3,79 @@ from typing import List
 from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards
 from exo.inference.shard import Shard
 
+
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
-    def test_map_partitions_to_shards(self):
-        partitions = [
-            Partition('node1', 0.0, 0.42857),
-            Partition('node2', 0.42857, 0.71428),
-            Partition('node3', 0.71428, 0.99999),
-        ]
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 12, 32),
-            Shard('model', 13, 21, 32),
-            Shard('model', 22, 31, 32),
-        ])
+  def test_map_partitions_to_shards(self):
+    partitions = [
+      Partition("node1", 0.0, 0.42857),
+      Partition("node2", 0.42857, 0.71428),
+      Partition("node3", 0.71428, 0.99999),
+    ]
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 12, 32),
+        Shard("model", 13, 21, 32),
+        Shard("model", 22, 31, 32),
+      ],
+    )
 
 
-        partitions = [
-            Partition('node1', 0.0, 0.1),
-            Partition('node2', 0.1, 0.2),
-            Partition('node3', 0.2, 1.0),
-        ]
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 2, 32),
-            Shard('model', 3, 5, 32),
-            Shard('model', 6, 31, 32),
-        ])
+    partitions = [
+      Partition("node1", 0.0, 0.1),
+      Partition("node2", 0.1, 0.2),
+      Partition("node3", 0.2, 1.0),
+    ]
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 2, 32),
+        Shard("model", 3, 5, 32),
+        Shard("model", 6, 31, 32),
+      ],
+    )
 
 
-        partitions = [
-            Partition('node1', 0.0, 1.0),
-        ]
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 31, 32),
-        ])
+    partitions = [
+      Partition("node1", 0.0, 1.0),
+    ]
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 31, 32),
+      ],
+    )
 
 
-        partitions = []
-        shards = map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [])
+    partitions = []
+    shards = map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(shards, [])
 
 
-    def test_broken_map_partitions_to_shards(self):
-        # this was an old broken implementation that sometimes had rounding errors!
-        def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
-            shards = []
-            for i, partition in enumerate(partitions):
-                start_layer = int(partition.start * num_layers)
-                end_layer = int(partition.end * num_layers) - 1
-                shards.append(Shard(model_id, start_layer, end_layer, num_layers))
-            return shards
+  def test_broken_map_partitions_to_shards(self):
+    # this was an old broken implementation that sometimes had rounding errors!
+    def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str):
+      shards = []
+      for i, partition in enumerate(partitions):
+        start_layer = int(partition.start * num_layers)
+        end_layer = int(partition.end * num_layers) - 1
+        shards.append(Shard(model_id, start_layer, end_layer, num_layers))
+      return shards
 
 
-        partitions = [
-            Partition('node1', 0.0, 0.42857),
-            Partition('node2', 0.42857, 0.71428),
-            Partition('node3', 0.71428, 0.99999),
-        ]
-        shards = _broken_map_partitions_to_shards(partitions, 32, 'model')
-        self.assertEqual(shards, [
-            Shard('model', 0, 12, 32),
-            Shard('model', 13, 21, 32),
-            Shard('model', 22, 30, 32),
-        ])
+    partitions = [
+      Partition("node1", 0.0, 0.42857),
+      Partition("node2", 0.42857, 0.71428),
+      Partition("node3", 0.71428, 0.99999),
+    ]
+    shards = _broken_map_partitions_to_shards(partitions, 32, "model")
+    self.assertEqual(
+      shards,
+      [
+        Shard("model", 0, 12, 32),
+        Shard("model", 13, 21, 32),
+        Shard("model", 22, 30, 32),
+      ],
+    )
 
 
-if __name__ == '__main__':
-    unittest.main()
 
 
+if __name__ == "__main__":
+  unittest.main()

+ 83 - 42
exo/topology/test_ring_memory_weighted_partitioning_strategy.py

@@ -4,46 +4,87 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
 
+
 class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase):
-    def test_partition(self):
-        # triangle
-        # node1 -> node2 -> node3 -> node1
-        topology = Topology()
-        topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.add_edge('node1', 'node2')
-        topology.add_edge('node2', 'node3')
-        topology.add_edge('node3', 'node1')
-        topology.add_edge('node1', 'node3')
-
-        strategy = RingMemoryWeightedPartitioningStrategy()
-        partitions = strategy.partition(topology)
-
-        self.assertEqual(len(partitions), 3)
-        self.assertEqual(partitions, [
-            Partition('node3', 0.0, 0.6),
-            Partition('node1', 0.6, 0.9),
-            Partition('node2', 0.9, 1.0),
-        ])
-
-    def test_partition_rounding(self):
-        # triangle
-        # node1 -> node2 -> node3 -> node1
-        topology = Topology()
-        topology.update_node('node1', DeviceCapabilities(model="MacBook Pro", chip="test1", memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node2', DeviceCapabilities(model="Mac Studio", chip="test2", memory=192*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-        topology.update_node('node3', DeviceCapabilities(model="MacBook Pro", chip="test3", memory=128*1024*1024*1024, flops=DeviceFlops(fp32=0, fp16=0, int8=0)))
-
-        strategy = RingMemoryWeightedPartitioningStrategy()
-        partitions = strategy.partition(topology)
-
-        self.assertEqual(len(partitions), 3)
-        self.assertEqual(partitions, [
-            Partition('node3', 0.0, 0.42857),
-            Partition('node1', 0.6, 0.9),
-            Partition('node2', 0.9, 1.0),
-        ])
-
-if __name__ == '__main__':
-    unittest.main()
+  def test_partition(self):
+    # triangle
+    # node1 -> node2 -> node3 -> node1
+    topology = Topology()
+    topology.update_node(
+      "node1",
+      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+    topology.update_node(
+      "node2",
+      DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+    topology.update_node(
+      "node3",
+      DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)),
+    )
+    topology.add_edge("node1", "node2")
+    topology.add_edge("node2", "node3")
+    topology.add_edge("node3", "node1")
+    topology.add_edge("node1", "node3")
+
+    strategy = RingMemoryWeightedPartitioningStrategy()
+    partitions = strategy.partition(topology)
+
+    self.assertEqual(len(partitions), 3)
+    self.assertEqual(
+      partitions,
+      [
+        Partition("node3", 0.0, 0.6),
+        Partition("node1", 0.6, 0.9),
+        Partition("node2", 0.9, 1.0),
+      ],
+    )
+
+  def test_partition_rounding(self):
+    # triangle
+    # node1 -> node2 -> node3 -> node1
+    topology = Topology()
+    topology.update_node(
+      "node1",
+      DeviceCapabilities(
+        model="MacBook Pro",
+        chip="test1",
+        memory=128 * 1024 * 1024 * 1024,
+        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      ),
+    )
+    topology.update_node(
+      "node2",
+      DeviceCapabilities(
+        model="Mac Studio",
+        chip="test2",
+        memory=192 * 1024 * 1024 * 1024,
+        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      ),
+    )
+    topology.update_node(
+      "node3",
+      DeviceCapabilities(
+        model="MacBook Pro",
+        chip="test3",
+        memory=128 * 1024 * 1024 * 1024,
+        flops=DeviceFlops(fp32=0, fp16=0, int8=0),
+      ),
+    )
+
+    strategy = RingMemoryWeightedPartitioningStrategy()
+    partitions = strategy.partition(topology)
+
+    self.assertEqual(len(partitions), 3)
+    self.assertEqual(
+      partitions,
+      [
+        Partition("node3", 0.0, 0.42857),
+        Partition("node1", 0.6, 0.9),
+        Partition("node2", 0.9, 1.0),
+      ],
+    )
+
+
+if __name__ == "__main__":
+  unittest.main()

+ 45 - 44
exo/topology/topology.py

@@ -1,48 +1,49 @@
 from .device_capabilities import DeviceCapabilities
 from typing import Dict, Set, Optional
 
+
 class Topology:
-    def __init__(self):
-        self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
-        self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
-        self.active_node_id: Optional[str] = None
-
-    def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
-        self.nodes[node_id] = device_capabilities
-
-    def get_node(self, node_id: str) -> DeviceCapabilities:
-        return self.nodes.get(node_id)
-
-    def all_nodes(self):
-        return self.nodes.items()
-
-    def add_edge(self, node1_id: str, node2_id: str):
-        if node1_id not in self.peer_graph:
-            self.peer_graph[node1_id] = set()
-        if node2_id not in self.peer_graph:
-            self.peer_graph[node2_id] = set()
-        self.peer_graph[node1_id].add(node2_id)
-        self.peer_graph[node2_id].add(node1_id)
-
-    def get_neighbors(self, node_id: str) -> Set[str]:
-        return self.peer_graph.get(node_id, set())
-
-    def all_edges(self):
-        edges = []
-        for node, neighbors in self.peer_graph.items():
-            for neighbor in neighbors:
-                if (neighbor, node) not in edges:  # Avoid duplicate edges
-                    edges.append((node, neighbor))
-        return edges
-
-    def merge(self, other: 'Topology'):
-        for node_id, capabilities in other.nodes.items():
-            self.update_node(node_id, capabilities)
-        for node_id, neighbors in other.peer_graph.items():
-            for neighbor in neighbors:
-                self.add_edge(node_id, neighbor)
-
-    def __str__(self):
-        nodes_str = ', '.join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
-        edges_str = ', '.join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
-        return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
+  def __init__(self):
+    self.nodes: Dict[str, DeviceCapabilities] = {}  # Maps node IDs to DeviceCapabilities
+    self.peer_graph: Dict[str, Set[str]] = {}  # Adjacency list representing the graph
+    self.active_node_id: Optional[str] = None
+
+  def update_node(self, node_id: str, device_capabilities: DeviceCapabilities):
+    self.nodes[node_id] = device_capabilities
+
+  def get_node(self, node_id: str) -> DeviceCapabilities:
+    return self.nodes.get(node_id)
+
+  def all_nodes(self):
+    return self.nodes.items()
+
+  def add_edge(self, node1_id: str, node2_id: str):
+    if node1_id not in self.peer_graph:
+      self.peer_graph[node1_id] = set()
+    if node2_id not in self.peer_graph:
+      self.peer_graph[node2_id] = set()
+    self.peer_graph[node1_id].add(node2_id)
+    self.peer_graph[node2_id].add(node1_id)
+
+  def get_neighbors(self, node_id: str) -> Set[str]:
+    return self.peer_graph.get(node_id, set())
+
+  def all_edges(self):
+    edges = []
+    for node, neighbors in self.peer_graph.items():
+      for neighbor in neighbors:
+        if (neighbor, node) not in edges:  # Avoid duplicate edges
+          edges.append((node, neighbor))
+    return edges
+
+  def merge(self, other: "Topology"):
+    for node_id, capabilities in other.nodes.items():
+      self.update_node(node_id, capabilities)
+    for node_id, neighbors in other.peer_graph.items():
+      for neighbor in neighbors:
+        self.add_edge(node_id, neighbor)
+
+  def __str__(self):
+    nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())
+    edges_str = ", ".join(f"{node}: {neighbors}" for node, neighbors in self.peer_graph.items())
+    return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})"
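
A minimal usage sketch (assuming only the module paths shown in this diff): edges are recorded in both directions, and merge takes the union of nodes and edges, with capabilities from the other topology overwriting local ones on conflict.

    from exo.topology.topology import Topology
    from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES

    a = Topology()
    a.update_node("node1", UNKNOWN_DEVICE_CAPABILITIES)
    a.add_edge("node1", "node2")   # adjacency recorded for both node1 and node2

    b = Topology()
    b.update_node("node3", UNKNOWN_DEVICE_CAPABILITIES)
    b.add_edge("node2", "node3")

    a.merge(b)
    print(a.get_neighbors("node2"))  # {'node1', 'node3'} (set order may vary)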

+ 49 - 29
exo/viz/test_topology_viz.py

@@ -5,38 +5,58 @@ from exo.topology.topology import Topology
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.partitioning_strategy import Partition
 
+
 class TestNodeViz(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        self.topology = Topology()
-        self.topology.update_node("node1", DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0,fp16=2.0,int8=4.0)))
-        self.topology.update_node("node2", DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0,fp16=4.0,int8=8.0)))
-        self.topology.update_node("node3", DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0,fp16=8.0,int8=16.0)))
-        self.topology.update_node("node4", DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0,fp16=16.0,int8=32.0)))
+  async def asyncSetUp(self):
+    self.topology = Topology()
+    self.topology.update_node(
+      "node1",
+      DeviceCapabilities(model="ModelA", chip="ChipA", memory=8 * 1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)),
+    )
+    self.topology.update_node(
+      "node2",
+      DeviceCapabilities(model="ModelB", chip="ChipB", memory=16 * 1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)),
+    )
+    self.topology.update_node(
+      "node3",
+      DeviceCapabilities(model="ModelC", chip="ChipC", memory=32 * 1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)),
+    )
+    self.topology.update_node(
+      "node4",
+      DeviceCapabilities(model="ModelD", chip="ChipD", memory=64 * 1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)),
+    )
+
+    self.top_viz = TopologyViz()
+    await asyncio.sleep(2)  # Simulate running for a short time
 
 
-        self.top_viz = TopologyViz()
-        await asyncio.sleep(2)  # Simulate running for a short time
+  async def test_layout_generation(self):
+    self.top_viz._generate_layout()
+    self.top_viz.refresh()
+    import time
 
 
-    async def test_layout_generation(self):
-        self.top_viz._generate_layout()
-        self.top_viz.refresh()
-        import time
-        time.sleep(2)
-        self.top_viz.update_visualization(self.topology, [
-            Partition("node1", 0, 0.2),
-            Partition("node4", 0.2, 0.4),
-            Partition("node2", 0.4, 0.8),
-            Partition("node3", 0.8, 0.9),
-        ])
-        time.sleep(2)
-        self.topology.active_node_id = "node3"
-        self.top_viz.update_visualization(self.topology, [
-            Partition("node1", 0, 0.3),
-            Partition("node5", 0.3, 0.5),
-            Partition("node2", 0.5, 0.7),
-            Partition("node4", 0.7, 0.9),
-        ])
-        time.sleep(2)
+    time.sleep(2)
+    self.top_viz.update_visualization(
+      self.topology,
+      [
+        Partition("node1", 0, 0.2),
+        Partition("node4", 0.2, 0.4),
+        Partition("node2", 0.4, 0.8),
+        Partition("node3", 0.8, 0.9),
+      ],
+    )
+    time.sleep(2)
+    self.topology.active_node_id = "node3"
+    self.top_viz.update_visualization(
+      self.topology,
+      [
+        Partition("node1", 0, 0.3),
+        Partition("node5", 0.3, 0.5),
+        Partition("node2", 0.5, 0.7),
+        Partition("node4", 0.7, 0.9),
+      ],
+    )
+    time.sleep(2)
 
 
 
 
 if __name__ == "__main__":
-    unittest.main()
+  unittest.main()

+ 155 - 155
exo/viz/topology_viz.py

@@ -8,161 +8,161 @@ from rich.panel import Panel
 from rich.text import Text
 from rich.live import Live
 from rich.style import Style
-from rich.color import Color
 from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
 
+
 class TopologyViz:
-    def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
-        self.chatgpt_api_endpoint = chatgpt_api_endpoint
-        self.web_chat_url = web_chat_url
-        self.topology = Topology()
-        self.partitions: List[Partition] = []
-
-        self.console = Console()
-        self.panel = Panel(self._generate_layout(), title=f"Exo Cluster (0 nodes)", border_style="bright_yellow")
-        self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
-        self.live_panel.start()
-
-    def update_visualization(self, topology: Topology, partitions: List[Partition]):
-        self.topology = topology
-        self.partitions = partitions
-        self.refresh()
-
-    def refresh(self):
-        self.panel.renderable = self._generate_layout()
-        # Update the panel title with the number of nodes and partitions
-        node_count = len(self.topology.nodes)
-        self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
-        self.live_panel.update(self.panel, refresh=True)
-
-    def _generate_layout(self) -> str:
-        # Calculate visualization parameters
-        num_partitions = len(self.partitions)
-        radius_x = 30  # Increased horizontal radius
-        radius_y = 12  # Decreased vertical radius
-        center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
-
-        # Generate visualization
-        visualization = [[' ' for _ in range(100)] for _ in range(55)]  # Decreased height
-
-        # Add exo_text at the top in bright yellow
-        exo_lines = exo_text.split('\n')
-        yellow_style = Style(color="bright_yellow")
-        max_line_length = max(len(line) for line in exo_lines)
-        for i, line in enumerate(exo_lines):
-            centered_line = line.center(max_line_length)
-            start_x = (100 - max_line_length) // 2 + 15 # Center the text plus empirical adjustment of 15
-            colored_line = Text(centered_line, style=yellow_style)
-            for j, char in enumerate(str(colored_line)):
-                if 0 <= start_x + j < 100 and i < len(visualization):
-                    visualization[i][start_x + j] = char
-
-        # Display chatgpt_api_endpoint and web_chat_url if set
-        info_lines = []
-        if self.web_chat_url:
-            info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
-        if self.chatgpt_api_endpoint:
-            info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
-
-        info_start_y = len(exo_lines) + 1
-        for i, line in enumerate(info_lines):
-            start_x = (100 - len(line)) // 2 + 15 # Center the info lines plus empirical adjustment of 15
-            for j, char in enumerate(line):
-                if 0 <= start_x + j < 100 and info_start_y + i < 55:
-                    visualization[info_start_y + i][start_x + j] = char
-
-        # Calculate total FLOPS and position on the bar
-        total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
-        bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
-
-        # Add GPU poor/rich bar
-        bar_width = 30  # Increased bar width
-        bar_start_x = (100 - bar_width) // 2  # Center the bar
-        bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
-        
-        # Create a gradient bar using emojis
-        gradient_bar = Text()
-        emojis = ['🟥', '🟧', '🟨', '🟩']  # Red, Orange, Yellow, Green
-        for i in range(bar_width):
-            emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
-            gradient_bar.append(emojis[emoji_index])
-
-        # Add the gradient bar to the visualization
-        visualization[bar_y][bar_start_x - 1] = '['
-        visualization[bar_y][bar_start_x + bar_width] = ']'
-        for i, segment in enumerate(str(gradient_bar)):
-            visualization[bar_y][bar_start_x + i] = segment
-        
-        # Add labels
-        visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = 'GPU poor'
-        visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = 'GPU rich'
-        
-        # Add position indicator and FLOPS value
-        pos_x = bar_start_x + int(bar_pos * bar_width)
-        flops_str = f"{total_flops:.2f} TFLOPS"
-        visualization[bar_y - 1][pos_x] = '▼'
-        visualization[bar_y + 1][pos_x - len(flops_str)//2:pos_x + len(flops_str)//2 + len(flops_str)%2] = flops_str
-        visualization[bar_y + 2][pos_x] = '▲'
-
-        for i, partition in enumerate(self.partitions):
-            device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
-
-            angle = 2 * math.pi * i / num_partitions
-            x = int(center_x + radius_x * math.cos(angle))
-            y = int(center_y + radius_y * math.sin(angle))
-
-            # Place node with different color for active node
-            if partition.node_id == self.topology.active_node_id:
-                visualization[y][x] = '🔴'  # Red circle for active node
-            else:
-                visualization[y][x] = '🔵'  # Blue circle for inactive nodes
-
-            # Place node info (model, memory, TFLOPS, partition) on three lines
-            node_info = [
-                f"{device_capabilities.model} {device_capabilities.memory // 1024}GB",
-                f"{device_capabilities.flops.fp16}TFLOPS",
-                f"[{partition.start:.2f}-{partition.end:.2f}]"
-            ]
-
-            # Calculate info position based on angle
-            info_distance_x = radius_x + 6  # Increased horizontal distance
-            info_distance_y = radius_y + 3  # Decreased vertical distance
-            info_x = int(center_x + info_distance_x * math.cos(angle))
-            info_y = int(center_y + info_distance_y * math.sin(angle))
-
-            # Adjust text position to avoid overwriting the node icon and prevent cutoff
-            if info_x < x:  # Text is to the left of the node
-                info_x = max(0, x - len(max(node_info, key=len)) - 1)
-            elif info_x > x:  # Text is to the right of the node
-                info_x = min(99 - len(max(node_info, key=len)), info_x)
-            
-            # Adjust for top and bottom nodes
-            if 5*math.pi/4 < angle < 7*math.pi/4:  # Node is near the top
-                info_x += 4  # Shift text slightly to the right
-            elif math.pi/4 < angle < 3*math.pi/4:  # Node is near the bottom
-                info_x += 3  # Shift text slightly to the right
-                info_y -= 2  # Move text up by two cells
-
-            for j, line in enumerate(node_info):
-                for k, char in enumerate(line):
-                    if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
-                        # Ensure we're not overwriting the node icon
-                        if info_y + j != y or info_x + k != x:
-                            visualization[info_y + j][info_x + k] = char
-
-            # Draw line to next node
-            next_i = (i + 1) % num_partitions
-            next_angle = 2 * math.pi * next_i / num_partitions
-            next_x = int(center_x + radius_x * math.cos(next_angle))
-            next_y = int(center_y + radius_y * math.sin(next_angle))
-
-            # Simple line drawing
-            steps = max(abs(next_x - x), abs(next_y - y))
-            for step in range(1, steps):
-                line_x = int(x + (next_x - x) * step / steps)
-                line_y = int(y + (next_y - y) * step / steps)
-                if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
-                    visualization[line_y][line_x] = '-'
-
-        # Convert to string
-        return '\n'.join(''.join(str(char) for char in row) for row in visualization)
+  def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None):
+    self.chatgpt_api_endpoint = chatgpt_api_endpoint
+    self.web_chat_url = web_chat_url
+    self.topology = Topology()
+    self.partitions: List[Partition] = []
+
+    self.console = Console()
+    self.panel = Panel(self._generate_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
+    self.live_panel = Live(self.panel, auto_refresh=False, console=self.console)
+    self.live_panel.start()
+
+  def update_visualization(self, topology: Topology, partitions: List[Partition]):
+    self.topology = topology
+    self.partitions = partitions
+    self.refresh()
+
+  def refresh(self):
+    self.panel.renderable = self._generate_layout()
+    # Update the panel title with the number of nodes and partitions
+    node_count = len(self.topology.nodes)
+    self.panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"
+    self.live_panel.update(self.panel, refresh=True)
+
+  def _generate_layout(self) -> str:
+    # Calculate visualization parameters
+    num_partitions = len(self.partitions)
+    radius_x = 30  # Increased horizontal radius
+    radius_y = 12  # Decreased vertical radius
+    center_x, center_y = 50, 28  # Centered horizontally and moved up slightly
+
+    # Generate visualization
+    visualization = [[" " for _ in range(100)] for _ in range(55)]  # Decreased height
+
+    # Add exo_text at the top in bright yellow
+    exo_lines = exo_text.split("\n")
+    yellow_style = Style(color="bright_yellow")
+    max_line_length = max(len(line) for line in exo_lines)
+    for i, line in enumerate(exo_lines):
+      centered_line = line.center(max_line_length)
+      start_x = (100 - max_line_length) // 2 + 15  # Center the text plus empirical adjustment of 15
+      colored_line = Text(centered_line, style=yellow_style)
+      for j, char in enumerate(str(colored_line)):
+        if 0 <= start_x + j < 100 and i < len(visualization):
+          visualization[i][start_x + j] = char
+
+    # Display chatgpt_api_endpoint and web_chat_url if set
+    info_lines = []
+    if self.web_chat_url:
+      info_lines.append(f"Web Chat URL (tinychat): {self.web_chat_url}")
+    if self.chatgpt_api_endpoint:
+      info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}")
+
+    info_start_y = len(exo_lines) + 1
+    for i, line in enumerate(info_lines):
+      start_x = (100 - len(line)) // 2 + 15  # Center the info lines plus empirical adjustment of 15
+      for j, char in enumerate(line):
+        if 0 <= start_x + j < 100 and info_start_y + i < 55:
+          visualization[info_start_y + i][start_x + j] = char
+
+    # Calculate total FLOPS and position on the bar
+    total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
+    bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
+
+    # Add GPU poor/rich bar
+    bar_width = 30  # Increased bar width
+    bar_start_x = (100 - bar_width) // 2  # Center the bar
+    bar_y = info_start_y + len(info_lines) + 1  # Position the bar below the info section with two cells of space
+
+    # Create a gradient bar using emojis
+    gradient_bar = Text()
+    emojis = ["🟥", "🟧", "🟨", "🟩"]  # Red, Orange, Yellow, Green
+    for i in range(bar_width):
+      emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1)
+      gradient_bar.append(emojis[emoji_index])
+
+    # Add the gradient bar to the visualization
+    visualization[bar_y][bar_start_x - 1] = "["
+    visualization[bar_y][bar_start_x + bar_width] = "]"
+    for i, segment in enumerate(str(gradient_bar)):
+      visualization[bar_y][bar_start_x + i] = segment
+
+    # Add labels
+    visualization[bar_y - 1][bar_start_x - 10 : bar_start_x - 3] = "GPU poor"
+    visualization[bar_y - 1][bar_start_x + bar_width * 2 + 2 : bar_start_x + bar_width * 2 + 11] = "GPU rich"
+
+    # Add position indicator and FLOPS value
+    pos_x = bar_start_x + int(bar_pos * bar_width)
+    flops_str = f"{total_flops:.2f} TFLOPS"
+    visualization[bar_y - 1][pos_x] = "▼"
+    visualization[bar_y + 1][pos_x - len(flops_str) // 2 : pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
+    visualization[bar_y + 2][pos_x] = "▲"
+
+    for i, partition in enumerate(self.partitions):
+      device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)
+
+      angle = 2 * math.pi * i / num_partitions
+      x = int(center_x + radius_x * math.cos(angle))
+      y = int(center_y + radius_y * math.sin(angle))
+
+      # Place node with different color for active node
+      if partition.node_id == self.topology.active_node_id:
+        visualization[y][x] = "🔴"  # Red circle for active node
+      else:
+        visualization[y][x] = "🔵"  # Blue circle for inactive nodes
+
+      # Place node info (model, memory, TFLOPS, partition) on three lines
+      node_info = [
+        f"{device_capabilities.model} {device_capabilities.memory // 1024}GB",
+        f"{device_capabilities.flops.fp16}TFLOPS",
+        f"[{partition.start:.2f}-{partition.end:.2f}]",
+      ]
+
+      # Calculate info position based on angle
+      info_distance_x = radius_x + 6  # Increased horizontal distance
+      info_distance_y = radius_y + 3  # Decreased vertical distance
+      info_x = int(center_x + info_distance_x * math.cos(angle))
+      info_y = int(center_y + info_distance_y * math.sin(angle))
+
+      # Adjust text position to avoid overwriting the node icon and prevent cutoff
+      if info_x < x:  # Text is to the left of the node
+        info_x = max(0, x - len(max(node_info, key=len)) - 1)
+      elif info_x > x:  # Text is to the right of the node
+        info_x = min(99 - len(max(node_info, key=len)), info_x)
+
+      # Adjust for top and bottom nodes
+      if 5 * math.pi / 4 < angle < 7 * math.pi / 4:  # Node is near the top
+        info_x += 4  # Shift text slightly to the right
+      elif math.pi / 4 < angle < 3 * math.pi / 4:  # Node is near the bottom
+        info_x += 3  # Shift text slightly to the right
+        info_y -= 2  # Move text up by two cells
+
+      for j, line in enumerate(node_info):
+        for k, char in enumerate(line):
+          if 0 <= info_y + j < 55 and 0 <= info_x + k < 100:  # Updated height check
+            # Ensure we're not overwriting the node icon
+            if info_y + j != y or info_x + k != x:
+              visualization[info_y + j][info_x + k] = char
+
+      # Draw line to next node
+      next_i = (i + 1) % num_partitions
+      next_angle = 2 * math.pi * next_i / num_partitions
+      next_x = int(center_x + radius_x * math.cos(next_angle))
+      next_y = int(center_y + radius_y * math.sin(next_angle))
+
+      # Simple line drawing
+      steps = max(abs(next_x - x), abs(next_y - y))
+      for step in range(1, steps):
+        line_x = int(x + (next_x - x) * step / steps)
+        line_y = int(y + (next_y - y) * step / steps)
+        if 0 <= line_y < 55 and 0 <= line_x < 100:  # Updated height check
+          visualization[line_y][line_x] = "-"
+
+    # Convert to string
+    return "\n".join("".join(str(char) for char in row) for row in visualization)
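
The marker on the GPU poor/rich bar is positioned by squashing the cluster's summed fp16 TFLOPS through tanh, so it saturates smoothly at both ends and sits mid-bar at 40 TFLOPS (where tanh(x/20 - 2) crosses zero). A minimal standalone sketch of that mapping; the sample TFLOPS values are purely illustrative:

import math

def bar_position(total_flops: float, bar_width: int = 30) -> int:
  # Same mapping as _generate_layout: tanh squashes total fp16 TFLOPS into (0, 1).
  bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2
  return int(bar_pos * bar_width)

for tflops in (5, 20, 40, 80, 160):
  print(f"{tflops:>5} TFLOPS -> column {bar_position(tflops)}")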

+ 110 - 0
format.py

@@ -0,0 +1,110 @@
+#!/usr/bin/env python
+import re
+import subprocess
+import sys
+import os
+import fnmatch
+
+DEBUG_PATTERN = re.compile(r'^(\s*)(if\s+DEBUG\s*>=?\s*\d+\s*:.+)$', re.MULTILINE)
+PLACEHOLDER = "###DEBUG_PLACEHOLDER###"
+
+# Add ignore patterns here
+IGNORE_PATTERNS = [
+  '.venv/*',
+  'setup.py',
+  '*helpers.py',
+  '*node_service_pb2.py',
+  '*node_service_pb2_grpc.py',
+]
+
+
+def should_ignore(file_path):
+  for pattern in IGNORE_PATTERNS:
+    if fnmatch.fnmatch(file_path, pattern):
+      return True
+  return False
+
+
+def preserve_debug_lines(content):
+  def replace(match):
+    indent, line = match.groups()
+    return f"{indent}{PLACEHOLDER}{line.strip()}"
+
+  return DEBUG_PATTERN.sub(replace, content)
+
+
+def restore_debug_lines(content):
+  return re.sub(f"^(\\s*){PLACEHOLDER}(.+)$", r"\1\2", content, flags=re.MULTILINE)
+
+
+def adjust_indentation(content):
+  lines = content.split('\n')
+  adjusted_lines = []
+  for line in lines:
+    if line.strip() and not line.startswith(PLACEHOLDER):
+      indent = len(line) - len(line.lstrip())
+      new_indent = ' ' * (indent // 2)
+      adjusted_lines.append(new_indent + line.lstrip())
+    else:
+      adjusted_lines.append(line)
+  return '\n'.join(adjusted_lines)
+
+
+def process_file(file_path, process_func):
+  with open(file_path, 'r') as file:
+    content = file.read()
+
+  modified_content = process_func(content)
+
+  if content != modified_content:
+    with open(file_path, 'w') as file:
+      file.write(modified_content)
+
+
+def run_black(target):
+  # Convert ignore patterns to Black's --extend-exclude format
+  exclude_patterns = '|'.join(f'({pattern.replace("*", ".*")})' for pattern in IGNORE_PATTERNS)
+  command = ["black", "--line-length", "200", "--extend-exclude", exclude_patterns, target]
+  subprocess.run(command, check=True)
+
+
+def format_files(target):
+  if os.path.isfile(target):
+    files = [target] if not should_ignore(target) else []
+  elif os.path.isdir(target):
+    files = []
+    for root, _, filenames in os.walk(target):
+      for filename in filenames:
+        if filename.endswith('.py'):
+          file_path = os.path.join(root, filename)
+          if not should_ignore(file_path):
+            files.append(file_path)
+  else:
+    print(f"Error: {target} is not a valid file or directory")
+    return
+
+  # Preserve debug lines
+  for file in files:
+    process_file(file, preserve_debug_lines)
+
+  # Run Black
+  run_black(target)
+
+  # Adjust indentation and restore debug lines
+  for file in files:
+    process_file(file, adjust_indentation)
+    process_file(file, restore_debug_lines)
+
+
+def main():
+  if len(sys.argv) < 2:
+    print("Usage: python format.py <directory_or_file>")
+    sys.exit(1)
+
+  target = sys.argv[1]
+  format_files(target)
+  print("Formatting completed.")
+
+
+if __name__ == "__main__":
+  main()
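
The placeholder mechanism above protects one-line `if DEBUG >= n: ...` statements, which Black would otherwise expand onto two lines: DEBUG lines are masked before formatting and unmasked afterwards. A small round-trip check of the two helpers; it assumes it is run from the repository root so that format.py is importable, and the snippet string is illustrative:

# Hypothetical sanity check for the DEBUG-line masking round trip.
from format import preserve_debug_lines, restore_debug_lines

snippet = '  if DEBUG >= 2: print("sending tensor")\n'
masked = preserve_debug_lines(snippet)    # '  ###DEBUG_PLACEHOLDER###if DEBUG >= 2: ...'
restored = restore_debug_lines(masked)    # back to the original one-liner
assert restored == snippet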

+ 5 - 0
lint.sh

@@ -0,0 +1,5 @@
+#!/bin/bash
+
+pip3 install -e '.[linting]'
+python3 -m ruff check .
+python3 -m pylint .

+ 12 - 3
main.py

@@ -2,7 +2,6 @@ import argparse
import asyncio
import signal
import uuid
-from typing import List
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
@@ -41,11 +40,21 @@ if args.node_port is None:
    if DEBUG >= 1: print(f"Using available port: {args.node_port}")

discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
-node = StandardNode(args.node_id, None, inference_engine, discovery, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions", web_chat_url=f"http://localhost:{args.chatgpt_api_port}", disable_tui=args.disable_tui, max_generate_tokens=args.max_generate_tokens)
+node = StandardNode(
+    args.node_id,
+    None,
+    inference_engine,
+    discovery,
+    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+    chatgpt_api_endpoint=f"http://localhost:{args.chatgpt_api_port}/v1/chat/completions",
+    web_chat_url=f"http://localhost:{args.chatgpt_api_port}",
+    disable_tui=args.disable_tui,
+    max_generate_tokens=args.max_generate_tokens,
+)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node, inference_engine.__class__.__name__, response_timeout_secs=args.chatgpt_api_response_timeout_secs)
-node.on_token.register("main_log").on_next(lambda _, tokens , __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
+node.on_token.register("main_log").on_next(lambda _, tokens, __: print(inference_engine.tokenizer.decode(tokens) if hasattr(inference_engine, "tokenizer") else tokens))
if args.prometheus_client_port:
    from exo.stats.metrics import start_metrics_server
    start_metrics_server(node, args.prometheus_client_port)
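
The `main_log` subscriber reads more clearly as a named function; a behaviour-preserving sketch is below. The parameter names are guesses, since the original lambda discards its first and third arguments:

def log_tokens(_request_id, tokens, _is_finished):
  # Mirrors the "main_log" lambda: decode when the engine exposes a tokenizer,
  # otherwise fall back to printing the raw token ids.
  if hasattr(inference_engine, "tokenizer"):
    print(inference_engine.tokenizer.decode(tokens))
  else:
    print(tokens)

node.on_token.register("main_log").on_next(log_tokens)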

+ 17 - 0
pyproject.toml

@@ -0,0 +1,17 @@
+[tool.black]
+line-length = 200
+indent-size = 2
+skip-string-normalization = true
+
+[tool.isort]
+profile = "black"
+line_length = 200
+indent = "  "
+
+[tool.pylint.format]
+indent-string = '  '
+max-line-length = 200
+
+[tool.autopep8]
+max_line_length = 200
+indent_size = 2

+ 43 - 0
ruff.toml

@@ -0,0 +1,43 @@
+indent-width = 2
+preview = true
+target-version = "py312"
+
+lint.select = [
+  "F",  # Pyflakes
+  "W6",
+  "E71",
+  "E72",
+  "E112",   # no-indented-block
+  "E113",   # unexpected-indentation
+  # "E124",
+  "E203",   # whitespace-before-punctuation
+  "E272",   # multiple-spaces-before-keyword
+  "E303",   # too-many-blank-lines
+  "E304",   # blank-line-after-decorator
+  "E501",   # line-too-long
+  # "E502",
+  "E702",   # multiple-statements-on-one-line-semicolon
+  "E703",   # useless-semicolon
+  "E731",   # lambda-assignment
+  "W191",   # tab-indentation
+  "W291",   # trailing-whitespace
+  "W293",   # blank-line-with-whitespace
+  "UP039",  # unnecessary-class-parentheses
+  "C416",   # unnecessary-comprehension
+  "RET506", # superfluous-else-raise
+  "RET507", # superfluous-else-continue
+  "A",      # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
+  "SIM105", # suppressible-exception
+  "FURB110",# if-exp-instead-of-or-operator
+]
+
+line-length = 200
+
+exclude = [
+  "docs/",
+  "examples/",
+  "extra/",
+  "exo/networking/grpc/node_service_pb2.py",
+  "exo/networking/grpc/node_service_pb2_grpc.py",
+  "exo/helpers.py",
+]

+ 10 - 1
setup.py

@@ -1,6 +1,7 @@
-from setuptools import setup, find_packages
import sys

+from setuptools import find_packages, setup
+
# Base requirements for all platforms
install_requires = [
    "aiohttp==3.9.5",
@@ -35,10 +36,18 @@ if sys.platform.startswith("darwin"):
        ]
    )

+extras_require = {
+    "linting": [
+        "pylint==3.2.6",
+        "ruff==0.5.5",
+        "mypy==1.11.0",
+    ]
+}
 
 
setup(
    name="exo",
    version="0.0.1",
    packages=find_packages(),
    install_requires=install_requires,
+    extras_require=extras_require,
)

+ 1 - 0
tinychat/examples/tinychat/index.html

@@ -64,6 +64,7 @@
        <option value="llama-3-70b">Llama 3 70B</option>
        <option value="mistral-nemo">Mistral Nemo</option>
        <option value="mistral-large">Mistral Large</option>
+        <option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
       </select>
       </select>
     </div>
     </div>
     <div class="home centered" x-show="home === 0" x-transition x-effect="
     <div class="home centered" x-show="home === 0" x-transition x-effect="