
Merge branch 'main' into main

Alex Cheema, 6 months ago
commit 675b1e16a8
64 changed files with 1584 additions and 1402 deletions
  1. .circleci/config.yml (+35 -35)
  2. .gitignore (+1 -1)
  3. .pylintrc (+0 -472)
  4. README.md (+25 -5)
  5. configure_mlx.sh (+18 -2)
  6. docs/exo-rounded.png (BIN)
  7. examples/astra/astra/ContentView.swift (+1 -1)
  8. examples/chatgpt_api.sh (+1 -1)
  9. exo/__init__.py (+1 -1)
  10. exo/api/chatgpt_api.py (+53 -41)
  11. exo/download/hf/hf_helpers.py (+34 -6)
  12. exo/download/hf/hf_shard_download.py (+8 -6)
  13. exo/download/shard_download.py (+4 -2)
  14. exo/helpers.py (+21 -1)
  15. exo/inference/debug_inference_engine.py (+11 -12)
  16. exo/inference/dummy_inference_engine.py (+17 -38)
  17. exo/inference/inference_engine.py (+22 -3)
  18. exo/inference/mlx/models/base.py (+1 -1)
  19. exo/inference/mlx/models/deepseek_v2.py (+1 -1)
  20. exo/inference/mlx/models/gemma2.py (+118 -0)
  21. exo/inference/mlx/models/qwen2.py (+2 -1)
  22. exo/inference/mlx/sharded_inference_engine.py (+46 -21)
  23. exo/inference/mlx/sharded_model.py (+0 -86)
  24. exo/inference/mlx/sharded_utils.py (+14 -8)
  25. exo/inference/mlx/stateful_model.py (+42 -0)
  26. exo/inference/mlx/test_sharded_llama.py (+4 -4)
  27. exo/inference/mlx/test_sharded_llava.py (+2 -2)
  28. exo/inference/test_dummy_inference_engine.py (+39 -42)
  29. exo/inference/test_inference_engine.py (+14 -24)
  30. exo/inference/tinygrad/inference.py (+34 -36)
  31. exo/inference/tinygrad/models/llama.py (+61 -36)
  32. exo/inference/tinygrad/stateful_model.py (+42 -0)
  33. exo/inference/tokenizers.py (+6 -1)
  34. exo/main.py (+76 -40)
  35. exo/models.py (+125 -50)
  36. exo/networking/grpc/grpc_peer_handle.py (+11 -8)
  37. exo/networking/grpc/grpc_server.py (+3 -5)
  38. exo/networking/grpc/node_service.proto (+2 -5)
  39. exo/networking/grpc/node_service_pb2.py (+0 -0)
  40. exo/networking/manual/manual_discovery.py (+9 -7)
  41. exo/networking/manual/network_topology_config.py (+0 -1)
  42. exo/networking/peer_handle.py (+3 -2)
  43. exo/networking/tailscale/tailscale_discovery.py (+11 -11)
  44. exo/networking/tailscale/tailscale_helpers.py (+15 -26)
  45. exo/networking/tailscale/test_tailscale_discovery.py (+2 -0)
  46. exo/networking/udp/udp_discovery.py (+14 -15)
  47. exo/orchestration/node.py (+2 -2)
  48. exo/orchestration/standard_node.py (+112 -122)
  49. exo/tinychat/index.css (+129 -62)
  50. exo/tinychat/index.html (+26 -25)
  51. exo/tinychat/index.js (+109 -14)
  52. exo/tinychat/update_deps.py (+44 -41)
  53. exo/topology/device_capabilities.py (+4 -2)
  54. extra/start_openwebui.sh (+1 -1)
  55. format.py (+1 -1)
  56. lint.sh (+0 -5)
  57. pyproject.toml (+0 -7)
  58. ruff.toml (+0 -43)
  59. scripts/build_exo.py (+60 -0)
  60. scripts/compile_grpc.sh (+7 -0)
  61. setup.py (+10 -13)
  62. test/reconnect.sh (+1 -1)
  63. test/test_model_helpers.py (+121 -0)
  64. test/test_tokenizers.py (+8 -3)

+ 35 - 35
.circleci/config.yml

@@ -20,12 +20,18 @@ commands:
           command: |
             source env/bin/activate
 
+            # Set CLANG=1 for tinygrad only
+            if [ "<<parameters.inference_engine>>" = "tinygrad" ]; then
+              pip install llvmlite
+              export TOKENIZERS_PARALLELISM=true SUPPORT_BF16=0 CLANG=1
+            fi
+
             # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output1.log &
             PID1=$!
 
             # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output2.log &
             PID2=$!
 
             # Wait for discovery
@@ -48,13 +54,6 @@ commands:
             # Check processes before proceeding
             check_processes
 
-            # Special handling for dummy engine
-            if [ "<<parameters.inference_engine>>" = "dummy" ]; then
-              expected_content="This is a dummy response"
-            else
-              expected_content="Michael Jackson"
-            fi
-
             echo "Sending request to first instance..."
             response_1=$(curl -s http://localhost:8000/v1/chat/completions \
               -H "Content-Type: application/json" \
@@ -127,6 +126,7 @@ jobs:
             METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
             echo "Running tokenizer tests..."
             python3 ./test/test_tokenizers.py
+            python3 ./test/test_model_helpers.py
 
   discovery_integration_test:
     macos:
@@ -149,9 +149,9 @@ jobs:
           name: Run discovery integration test
           command: |
             source env/bin/activate
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --disable-tui > output1.log 2>&1 &
             PID1=$!
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --disable-tui > output2.log 2>&1 &
             PID2=$!
             sleep 10
             kill $PID1 $PID2
@@ -223,29 +223,29 @@ jobs:
       - checkout
       - run: system_profiler SPHardwareDataType
 
-  # chatgpt_api_integration_test_tinygrad:
-  #   macos:
-  #     xcode: "16.0.0"
-  #   resource_class: m2pro.large
-  #   steps:
-  #     - checkout
-  #     - run:
-  #         name: Set up Python
-  #         command: |
-  #           brew install python@3.12
-  #           python3.12 -m venv env
-  #           source env/bin/activate
-  #     - run:
-  #         name: Install dependencies
-  #         command: |
-  #           source env/bin/activate
-  #           pip install --upgrade pip
-  #           pip install .
-  #     - run_chatgpt_api_test:
-  #         inference_engine: tinygrad
-  #         model_id: llama-3-8b
-  #         prompt: "Keep responses concise. Who was the king of pop?"
-  #         expected_output: "Michael Jackson"
+  chatgpt_api_integration_test_tinygrad:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: tinygrad
+          model_id: llama-3.2-1b
+          prompt: "Keep responses concise. Who was the king of pop?"
+          expected_output: "Michael Jackson"
 
 workflows:
   version: 2
@@ -254,6 +254,6 @@ workflows:
       - unit_test
       - discovery_integration_test
       - chatgpt_api_integration_test_mlx
+      - chatgpt_api_integration_test_tinygrad
       - chatgpt_api_integration_test_dummy
       - test_macos_m1
-      # - chatgpt_api_integration_test_tinygrad

+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/
 
 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

+ 0 - 472
.pylintrc

@@ -1,472 +0,0 @@
-[MASTER]
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code
-extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=.*node_service_pb2.*
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint.
-jobs=4
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Specify a configuration file.
-#rcfile=
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
-# E1101 for function binding
-# W0221 for Function class
-# W0105 for comment strings
-# E0401 for missing imports
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member,use-a-generator, no-else-return
-
-
-[REPORTS]
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio).You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Activate the evaluation score.
-score=yes
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=optparse.Values,sys.exit
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-
-[SIMILARITIES]
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging  or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='  '
-
-# Maximum number of characters on a single line.
-max-line-length=150
-
-# Maximum number of lines in a module
-max-module-lines=1000
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[BASIC]
-
-# Naming style matching correct argument names
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style
-#argument-rgx=
-
-# Naming style matching correct attribute names
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=foo,
-          bar,
-          baz,
-          toto,
-          tutu,
-          tata
-
-# Naming style matching correct class attribute names
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style
-#class-attribute-rgx=
-
-# Naming style matching correct class names
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-style
-#class-rgx=
-
-# Naming style matching correct constant names
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=i,
-           j,
-           k,
-           ex,
-           Run,
-           _
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style
-#inlinevar-rgx=
-
-# Naming style matching correct method names
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style
-#method-rgx=
-
-# Naming style matching correct module names
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-property-classes=abc.abstractproperty
-
-# Naming style matching correct variable names
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style
-#variable-rgx=
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in a if statement
-max-bool-expr=5
-
-# Maximum number of branch for function / method body
-max-branches=12
-
-# Maximum number of locals for function / method body
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body
-max-returns=6
-
-# Maximum number of statements in function / method body
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[IMPORTS]
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,
-                   TERMIOS,
-                   Bastion,
-                   rexec
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-[STRING]
-
-# This flag controls whether the implicit-str-concat should generate a warning
-# on implicit string concatenation in sequences defined over several lines.
-check-str-concat-over-line-jumps=yes
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=builtins.Exception

+ 25 - 5
README.md

@@ -121,14 +121,14 @@ exo
 
 That's it! No configuration required - exo will automatically discover the other device(s).
 
-exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
+exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:52415
 
-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:52415/v1/chat/completions. Examples with curl:
 
 #### Llama 3.2 3B:
 
 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llama-3.2-3b",
@@ -140,7 +140,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llama 3.1 405B:
 
 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llama-3.1-405b",
@@ -152,7 +152,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llava 1.5 7B (Vision Language Model):
 
 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
      "model": "llava-1.5-7b-hf",
@@ -208,6 +208,12 @@ With a custom prompt:
 exo run llama-3.2-3b --prompt "What is the meaning of exo?"
 ```
 
+### Model Storage
+
+Models by default are stored in `~/.cache/huggingface/hub`.
+
+You can set a different model storage location by setting the `HF_HOME` env var.
+
 ## Debugging
 
 Enable debug logs with the DEBUG environment variable (0-9).
@@ -222,6 +228,20 @@ For the **tinygrad** inference engine specifically, there is a separate DEBUG fl
 TINYGRAD_DEBUG=2 exo
 ```
 
+## Formatting
+
+We use [yapf](https://github.com/google/yapf) to format the code. To format the code, first install the formatting requirements:
+
+```sh
+pip3 install -e '.[formatting]'
+```
+
+Then run the formatting script:
+
+```sh
+python3 format.py ./exo
+```
+
 ## Known Issues
 
 - On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
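
For reference, the same request as the curl example above can be issued from Python. This is a minimal sketch, not part of the commit; it assumes a local exo instance on the new default port 52415 and an OpenAI-style response body.

```python
import json
from urllib.request import Request, urlopen

# Same payload as the README's curl example for llama-3.2-3b
payload = {
  "model": "llama-3.2-3b",
  "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
  "temperature": 0.7,
}
req = Request(
  "http://localhost:52415/v1/chat/completions",
  data=json.dumps(payload).encode("utf-8"),
  headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
  body = json.load(resp)
# Assumes the standard OpenAI-style response shape
print(body["choices"][0]["message"]["content"])
```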

+ 18 - 2
configure_mlx.sh

@@ -1,2 +1,18 @@
-sudo sysctl iogpu.wired_lwm_mb=400000
-sudo sysctl iogpu.wired_limit_mb=180000
+#!/bin/bash
+
+# Get the total memory in MB
+TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
+
+# Set WIRED_LIMIT_MB to 80%
+WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
+# Set  WIRED_LWM_MB to 70%
+WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
+
+# Display the calculated values
+echo "Total memory: $TOTAL_MEM_MB MB"
+echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
+echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
+
+# Apply the values with sysctl
+sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB

BIN
docs/exo-rounded.png


+ 1 - 1
examples/astra/astra/ContentView.swift

@@ -148,7 +148,7 @@ struct ContentView: View {
     @State private var voiceActivityThreshold: Float = 0.40
     @State private var silenceTimeThreshold = 1.0
     @State private var debugText = ""
-    @State private var apiEndpoint = "http://192.168.212.74:8000/v1/chat/completions"
+    @State private var apiEndpoint = "http://192.168.212.74:52415/v1/chat/completions"
     @State private var audioBuffer: [Float] = []
     @State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
     @State private var isInitialTranscription = true

+ 1 - 1
examples/chatgpt_api.sh

@@ -3,7 +3,7 @@
 # This works the same in a single-node set up and in a multi-node setup.
 # You need to start exo before running this by running `python3 main.py`.
 
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 MODEL="llama-3.1-8b"
 PROMPT="What is the meaning of exo?"
 TEMPERATURE=0.7

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 53 - 41
exo/api/chatgpt_api.py

@@ -8,15 +8,14 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import signal
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
-from exo.inference.shard import Shard
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import model_base_shards
-from typing import Callable
-
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from typing import Callable, Optional
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -27,6 +26,7 @@ class Message:
     return {"role": self.role, "content": self.content}
 
 
+
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
@@ -117,19 +117,11 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  image_str = None
   for message in messages:
     if not isinstance(message.content, list):
       continue
 
-    for content in message.content:
-      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
-      # follows the convention in https://platform.openai.com/docs/guides/vision
-      if isinstance(content, dict) and content.get("type", None) == "image":
-        image_str = content.get("image", None)
-        break
-
-  return prompt, image_str
+  return prompt
 
 
 def parse_message(data: dict):
@@ -138,9 +130,9 @@ def parse_message(data: dict):
   return Message(data["role"], data["content"])
 
 
-def parse_chat_request(data: dict):
+def parse_chat_request(data: dict, default_model: str):
   return ChatCompletionRequest(
-    data.get("model", "llama-3.1-8b"),
+    data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
   )
@@ -152,9 +144,8 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt
 
-
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -163,6 +154,8 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
+    self.default_model = default_model or "llama-3.2-1b"
+
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
       allow_credentials=True,
@@ -177,13 +170,24 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
+    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
+    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent / "tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)
 
   async def timeout_middleware(self, app, handler):
     async def middleware(request):
@@ -191,6 +195,7 @@ class ChatGPTAPI:
         return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
       except asyncio.TimeoutError:
         return web.json_response({"detail": "Request timed out"}, status=408)
+
     return middleware
 
   async def log_request(self, app, handler):
@@ -203,14 +208,25 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
+  async def handle_healthcheck(self, request):
+    return web.json_response({"status": "ok"})
+
+  async def handle_model_support(self, request):
+    return web.json_response({
+      "model pool": {
+        model_name: pretty_name.get(model_name, model_name) 
+        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
+      }
+    })
+  
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = build_base_shard(self.default_model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
 
   async def handle_get_download_progress(self, request):
@@ -222,29 +238,28 @@ class ChatGPTAPI:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
     return web.json_response(progress_data)
 
-
   async def handle_post_chat_completions(self, request):
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
-    chat_request = parse_chat_request(data)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
-      chat_request.model = "llama-3.1-8b"
-    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
+    chat_request = parse_chat_request(data, self.default_model)
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
+    if not chat_request.model or chat_request.model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      chat_request.model = self.default_model
+    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
-      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
       )
 
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
@@ -267,13 +282,10 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
 
     try:
-      await asyncio.wait_for(
-        asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))),
-        timeout=self.response_timeout
-      )
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 
@@ -356,7 +368,7 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
-  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+  async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
     site = web.TCPSite(runner, host, port)
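
Not part of the diff, but a minimal sketch of how the endpoints added above can be exercised, assuming a node is already serving on the new default port 52415:

```python
import asyncio
import aiohttp

async def main():
  async with aiohttp.ClientSession() as session:
    # New healthcheck endpoint added in this diff
    async with session.get("http://localhost:52415/healthcheck") as resp:
      print(await resp.json())  # expected: {"status": "ok"}
    # New model pool endpoint: models supported by the current topology
    async with session.get("http://localhost:52415/modelpool") as resp:
      print(await resp.json())

asyncio.run(main())
```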

+ 34 - 6
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
+import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios
 
 T = TypeVar("T")
 
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -70,8 +73,10 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
   return pattern
 
+
 def get_hf_endpoint() -> str:
-    return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+
 
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
@@ -97,9 +102,22 @@ async def get_auth_headers():
 
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))
+        
 
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
@@ -394,7 +412,7 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
+  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
@@ -407,6 +425,16 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     elif shard.is_last_layer():
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
+
+async def has_hf_home_read_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.R_OK)
+  except OSError: return False
+
+async def has_hf_home_write_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.W_OK)
+  except OSError: return False
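
A minimal sketch (not part of the commit) of the new access-check helpers; it assumes only the function names shown above and that `get_hf_home()` resolves the model storage location (honoring `HF_HOME`):

```python
import asyncio
from exo.download.hf.hf_helpers import get_hf_home, has_hf_home_read_access, has_hf_home_write_access

async def main():
  # get_hf_home() returns the Hugging Face home directory used for model storage
  print("HF home:", get_hf_home())
  print("readable:", await has_hf_home_read_access())
  print("writable:", await has_hf_home_write_access())

asyncio.run(main())
```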

+ 8 - 6
exo/download/hf/hf_shard_download.py

@@ -7,6 +7,7 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import model_cards, get_repo
 
 
 class HFShardDownloader(ShardDownloader):
@@ -17,11 +18,12 @@ class HFShardDownloader(ShardDownloader):
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
 
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
     if self.quick_check:
-      repo_root = get_repo_root(shard.model_id)
+      repo_root = get_repo_root(repo_name)
       snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
         visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
@@ -51,7 +53,7 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
 
     # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard))
+    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
     self.active_downloads[shard] = download_task
     try:
       path = await download_task
@@ -63,14 +65,14 @@ class HFShardDownloader(ShardDownloader):
       if shard in self.active_downloads:
         self.active_downloads.pop(shard)
 
-  async def _download_shard(self, shard: Shard) -> Path:
+  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
     async def wrapped_progress_callback(event: RepoProgressEvent):
       self._on_progress.trigger_all(shard, event)
 
-    weight_map = await get_weight_map(shard.model_id)
+    weight_map = await get_weight_map(repo_name)
     allow_patterns = get_allow_patterns(weight_map, shard)
 
-    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
+    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
 
   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:

+ 4 - 2
exo/download/shard_download.py

@@ -8,7 +8,7 @@ from exo.helpers import AsyncCallbackSystem
 
 class ShardDownloader(ABC):
   @abstractmethod
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
@@ -17,6 +17,7 @@ class ShardDownloader(ABC):
 
         Args:
             shard (Shard): The shard to download.
+            inference_engine_name (str): The inference engine used on the node hosting the shard
         """
     pass
 
@@ -25,8 +26,9 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
 
+
 class NoopShardDownloader(ShardDownloader):
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     return Path("/tmp/noop_shard")
 
   @property

+ 21 - 1
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -170,7 +171,7 @@ def is_valid_uuid(val):
 
 
 def get_or_create_node_id():
-  NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
+  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop, server):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
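
The call sites for `shutdown()` are not shown in this hunk; a hypothetical wiring sketch, assuming only the signature above, might look like:

```python
import asyncio
import signal
from exo.helpers import shutdown

def install_signal_handlers(loop: asyncio.AbstractEventLoop, server):
  # Route SIGINT/SIGTERM to the graceful shutdown helper added above
  for sig in (signal.SIGINT, signal.SIGTERM):
    loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(shutdown(s, loop, server)))
```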

+ 11 - 12
exo/inference/debug_inference_engine.py

@@ -13,32 +13,31 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
 
   prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
-    inference_state=inference_state_1,
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  token2 = await inference_engine_2.sample(resp2)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=token2,
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
-    inference_state=inference_state_3,
   )
 
   print(f"{resp2=}")

+ 17 - 38
exo/inference/dummy_inference_engine.py

@@ -1,59 +1,38 @@
 from typing import Optional, Tuple, TYPE_CHECKING
 import numpy as np
+import random
+import string
 import asyncio
 import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
+def random_string(length: int):
+  return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
+  
 
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None
     self.vocab_size = 1000
+    self.hidden_size = 256
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-    try:
-      await self.ensure_shard(shard)
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size)
 
-      # Generate random tokens
-      output_length = np.random.randint(1, 10)
-      output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
 
-      # Simulate latency
-      await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-      # Randomly decide if finished
-      is_finished = np.random.random() < 0.2
-      if is_finished:
-        output = np.array([[self.eos_token_id]])
-
-      new_state = json.dumps({"dummy_state": "some_value"})
-
-      return output, new_state, is_finished
-    except Exception as e:
-      print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
-      return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    state = json.loads(inference_state or "{}")
-    start_pos = state.get("start_pos", 0)
-
-    output_length = np.random.randint(1, 10)
-    output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-
-    await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-    is_finished = np.random.random() < 0.2
-    if is_finished:
-      output = np.array([[self.eos_token_id]])
-
-    start_pos += input_data.shape[1] + output_length
-    new_state = json.dumps({"start_pos": start_pos})
-
-    return output, new_state, is_finished
+    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
+    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
+    return output
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:

+ 22 - 3
exo/inference/inference_engine.py

@@ -9,13 +9,32 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  @abstractmethod
+  async def sample(self, x: np.ndarray) -> np.ndarray:
     pass
 
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass
 
+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    pass
+  
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
+    tokens = await self.encode(shard, prompt)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
+    return output_data 
+
+inference_engine_classes = {
+  "mlx": "MLXDynamicShardInferenceEngine",
+  "tinygrad": "TinygradDynamicShardInferenceEngine",
+  "dummy": "DummyInferenceEngine",
+}
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:
@@ -33,4 +52,4 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
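
Not part of the diff, but a minimal sketch of the new stateless call flow (encode/infer via `infer_prompt`, then `sample` and `decode`), using the dummy engine and the `Shard` signature that appear elsewhere in this commit. The model id and layer counts are hypothetical, for illustration only.

```python
import asyncio
import numpy as np
from exo.inference.shard import Shard
from exo.inference.dummy_inference_engine import DummyInferenceEngine

async def main():
  engine = DummyInferenceEngine()
  # A single shard covering the whole (hypothetical) model
  shard = Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8)
  logits = await engine.infer_prompt("req-1", shard, "What is the meaning of exo?")  # encode() + infer_tensor()
  token = await engine.sample(logits)                    # pick a next token from the output
  print(await engine.decode(shard, np.array([token])))   # map tokens back to (random) text

asyncio.run(main())
```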

+ 1 - 1
exo/inference/mlx/models/base.py

@@ -1,7 +1,7 @@
 from typing import Optional
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 
 
 class IdentityBlock(nn.Module):

+ 1 - 1
exo/inference/mlx/models/deepseek_v2.py

@@ -4,7 +4,7 @@ from typing import Optional
 import mlx.core as mx
 import mlx.nn as nn
 
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
 from exo.inference.shard import Shard

+ 118 - 0
exo/inference/mlx/models/gemma2.py

@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class GemmaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if args.shard.is_first_layer() or args.shard.is_last_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if args.shard.start_layer <= i <= args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if args.shard.is_last_layer():
+      self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+      h = h * (self.args.hidden_size**0.5)
+    else:
+      h = inputs
+
+    mask = None
+    if h.ndim > 1 and h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = GemmaModel(args)
+    if args.shard.is_last_layer():
+      self.final_logit_softcapping = args.final_logit_softcapping
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.model.embed_tokens.as_linear(out)
+      out = mx.tanh(out / self.final_logit_softcapping)
+      out = out * self.final_logit_softcapping
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.head_dim
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
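
As in the other sharded MLX models, only layers inside [start_layer, end_layer] become real TransformerBlocks; every other slot is an IdentityBlock, and the embedding, final RMSNorm and soft-capped logit head only exist on the shards that need them. A small illustration with a hypothetical 26-layer model split across two nodes (model_id and layer count are placeholders):

from exo.inference.shard import Shard

first_half = Shard(model_id="gemma2-example", start_layer=0, end_layer=12, n_layers=26)
second_half = Shard(model_id="gemma2-example", start_layer=13, end_layer=25, n_layers=26)

assert first_half.is_first_layer() and not first_half.is_last_layer()
assert second_half.is_last_layer()
# The first shard embeds tokens and runs layers 0-12; the second receives hidden
# states, runs layers 13-25, then applies the final norm and logit softcapping.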

+ 2 - 1
exo/inference/mlx/models/qwen2.py

@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
+
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -57,7 +58,7 @@ class Qwen2Model(nn.Module):
       mask = create_attention_mask(h, cache)
 
     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)
 
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

+ 46 - 21
exo/inference/mlx/sharded_inference_engine.py

@@ -1,14 +1,35 @@
 import numpy as np
 import mlx.core as mx
+import mlx.nn as nn
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulShardedModel
+from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
+from mlx_lm.sample_utils import top_p_sampling
+
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> Tuple[mx.array, float]:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
+
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -16,35 +37,39 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
+    return out
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    if image_str:
-      image = await get_image_from_str(image_str)
-      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
-      inputs = await loop.run_in_executor(self.executor, tokenize)
-      pixel_values = mx.array(inputs["pixel_values"])
-      input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
-    else:
-      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return np.array(tokens)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    return output_data
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
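
sample_logits reproduces the sampling logic that previously lived in StatefulShardedModel.step: argmax at temperature 0, nucleus sampling when 0 < top_p < 1, otherwise a plain categorical draw over the temperature-scaled distribution. The same decision tree in plain NumPy, for reference (sample_last_token is an illustrative name, and mlx_lm's top_p_sampling may differ in edge cases):

import numpy as np

def sample_last_token(logits: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> int:
  # logits: (vocab,) scores for the last position produced by the final shard.
  if temp == 0:
    return int(np.argmax(logits))
  scaled = logits / temp
  probs = np.exp(scaled - np.max(scaled))
  probs /= probs.sum()
  if 0 < top_p < 1.0:
    order = np.argsort(-probs)                                   # most probable first
    cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
    keep = order[:cutoff]                                        # smallest prefix covering top_p mass
    masked = np.zeros_like(probs)
    masked[keep] = probs[keep]
    probs = masked / masked.sum()
  return int(np.random.choice(len(probs), p=probs))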

+ 0 - 86
exo/inference/mlx/sharded_model.py

@@ -1,86 +0,0 @@
-from typing import Dict, Generator, Optional, Tuple
-from collections import OrderedDict
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import KVCache, RotatingKVCache
-from mlx_lm.sample_utils import top_p_sampling
-
-from ..shard import Shard
-
-# TODO: support a speculative model so we can parallelise compute across devices
-class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
-    self.shard = shard
-    self.model = model
-    self.max_kv_size = max_kv_size
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
-
-  def step(
-    self,
-    request_id: str,
-    x,
-    pixel_values=None,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-      if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-        logits[:, indices] += values
-
-      if temp == 0:
-        token = mx.argmax(logits, axis=-1)
-      else:
-        if top_p > 0 and top_p < 1.0:
-          token = top_p_sampling(logits, top_p, temp)
-        else:
-          token = mx.random.categorical(logits*(1/temp))
-
-      return token
-
-    y = x
-
-    if request_id not in self.caches:
-      self.init_cache(request_id)
-    else:
-      self.caches.move_to_end(request_id)
-
-    cache = self.caches[request_id]
-
-    if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
-    else:
-      output = self.model(y, pixel_values=pixel_values, cache=cache)
-
-    if self.shard.is_last_layer():
-      logits = output[:, -1, :]
-      y = sample(logits)
-      return y
-    else:
-      return output
-
-  def __call__(
-    self,
-    request_id: str,
-    x,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
-
-  def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    if self.max_kv_size is not None:
-      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-    else:
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache

+ 14 - 8
exo/inference/mlx/sharded_utils.py

@@ -12,15 +12,16 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
+import traceback
 
 import mlx.core as mx
 import mlx.nn as nn
 from transformers import AutoProcessor
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
-from mlx_lm.tuner.utils import apply_lora_layers
 
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 
 
@@ -53,6 +54,7 @@ def _get_classes(config: dict):
   except ImportError:
     msg = f"Model type {model_type} not supported."
     logging.error(msg)
+    traceback.print_exc()
     raise ValueError(msg)
 
   return arch.Model, arch.ModelArgs
@@ -67,7 +69,6 @@ def load_config(model_path: Path) -> dict:
     raise
   return config
 
-
 def load_model_shard(
   model_path: Path,
   shard: Shard,
@@ -130,8 +131,17 @@ def load_model_shard(
 
   model_class, model_args_class = _get_classes(config=config)
 
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x, *args, **kwargs)
+      return y
+
   model_args = model_args_class.from_dict(config)
-  model = model_class(model_args)
+  model = ShardedModel(model_args)
 
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
@@ -157,7 +167,6 @@ def load_model_shard(
   model.eval()
   return model
 
-
 async def load_shard(
   model_path: str,
   shard: Shard,
@@ -167,9 +176,6 @@ async def load_shard(
   lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   model = load_model_shard(model_path, shard, lazy, model_config)
-  if adapter_path is not None:
-    model = apply_lora_layers(model, adapter_path)
-    model.eval()
 
   # TODO: figure out a generic solution
   if model.model_type == "llava":
@@ -178,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     return model, processor
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
 
 

+ 42 - 0
exo/inference/mlx/stateful_model.py

@@ -0,0 +1,42 @@
+from typing import Dict, Tuple
+from collections import OrderedDict
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import make_prompt_cache
+
+from ..shard import Shard
+
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+  
+  def init_cache(self, request_id: str):
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    # if self.max_kv_size is not None:
+      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str):
+    if request_id not in self.caches:
+      self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
+
+    y = self.model(x, cache=cache)
+    return y
+    
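
StatefulModel keeps one mlx_lm prompt cache per request_id in an OrderedDict used as an LRU, capped at max_caches (default 2): a cache is created on first use, refreshed on every call, and the least recently used entry is evicted when the cap is reached. The policy in isolation (a self-contained sketch around a dummy state factory; none of these names come from the repo):

from collections import OrderedDict

class LRUState:
  def __init__(self, make_state, max_entries: int = 2):
    self.make_state = make_state
    self.max_entries = max_entries
    self.entries = OrderedDict()

  def get(self, request_id: str):
    if request_id not in self.entries:
      if len(self.entries) >= self.max_entries:
        self.entries.popitem(last=False)         # evict the least recently used cache
      self.entries[request_id] = self.make_state()
    else:
      self.entries.move_to_end(request_id)       # mark as most recently used
    return self.entries[request_id]

states = LRUState(make_state=list)
states.get("request-a"); states.get("request-b")
states.get("request-a")                          # refreshes request-a
states.get("request-c")                          # evicts request-b
assert list(states.entries) == ["request-a", "request-c"]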

+ 4 - 4
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 
@@ -12,9 +12,9 @@ full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Ins
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
+full = StatefulModel(full_model_shard)
+m1 = StatefulModel(model_shard1)
+m2 = StatefulModel(model_shard2)
 
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))

+ 2 - 2
exo/inference/mlx/test_sharded_llava.py

@@ -5,9 +5,9 @@ from PIL import Image
 from io import BytesIO
 
 import mlx.core as mx
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 

+ 39 - 42
exo/inference/test_dummy_inference_engine.py

@@ -4,53 +4,50 @@ import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
+
 class MockShardDownloader:
-    async def ensure_shard(self, shard):
-        pass
+  async def ensure_shard(self, shard):
+    pass
+
+
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-    engine = DummyInferenceEngine(MockShardDownloader())
-    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    test_prompt = "This is a test prompt"
-    
-    result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
-    
-    print(f"Inference result shape: {result.shape}")
-    print(f"Inference state: {state}")
-    print(f"Is finished: {is_finished}")
-    
-    assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
-    assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+  engine = DummyInferenceEngine(MockShardDownloader())
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+
+  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+
+  print(f"Inference result shape: {result.shape}")
+
+  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+
 
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
-    # Initialize the DummyInferenceEngine
-    engine = DummyInferenceEngine(MockShardDownloader())
-    
-    # Create a test shard
-    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    
-    # Test infer_prompt
-    output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    # Test infer_tensor
-    input_tensor = np.array([[1, 2, 3]])
-    output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    print("All tests passed!")
+  # Initialize the DummyInferenceEngine
+  engine = DummyInferenceEngine(MockShardDownloader())
+
+  # Create a test shard
+  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+
+  # Test infer_prompt
+  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  # Test infer_tensor
+  input_tensor = np.array([[1, 2, 3]])
+  output = await engine.infer_tensor("test_id", shard, input_tensor)
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  print("All tests passed!")
+
 
 if __name__ == "__main__":
-    import asyncio
-    asyncio.run(test_dummy_inference_engine())
-    asyncio.run(test_dummy_inference_specific())
+  import asyncio
+  asyncio.run(test_dummy_inference_engine())
+  asyncio.run(test_dummy_inference_specific())

+ 14 - 24
exo/inference/test_inference_engine.py

@@ -11,45 +11,40 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+  token_full = token_full.reshape(1, -1)
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )
 
   pp = n_layers // 2
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
-    inference_state=inference_state_1,
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  tokens2 = await inference_engine_1.sample(resp2)
+  tokens2 = tokens2.reshape(1, -1)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=tokens2,
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
-    inference_state=inference_state_3,
   )
 
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Llama-3.2-1B-Instruct-4bit",
-  16
-))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
@@ -57,10 +52,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
-    test_inference_engine(
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-      32
-    )
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
   )

+ 34 - 36
exo/inference/tinygrad/inference.py

@@ -1,17 +1,17 @@
 from pathlib import Path
 import json
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
 import asyncio
 
 Tensor.no_grad = True
@@ -22,7 +22,17 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 MODEL_PARAMS = {
-  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "1B": {
+    "args": {
+      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "3B": {
+    "args": {
+      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
   "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 
@@ -30,8 +40,7 @@ MODEL_PARAMS = {
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
-  with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
   # load weights
   if model_path.is_dir():
@@ -48,54 +57,43 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model
 
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
+  
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
 
-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
-      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+      loop = asyncio.get_running_loop()
+      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

+ 61 - 36
exo/inference/tinygrad/models/llama.py

@@ -1,11 +1,23 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
+from collections import OrderedDict
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
+    freqs *= (original_max_pos_emb/end)**(1.0/factor)
+
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
@@ -36,7 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
-
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
@@ -50,7 +61,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
@@ -65,19 +76,16 @@ class Attention:
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     bsz, seqlen, _, _ = xq.shape
 
-    # create kv cache
-    if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
-      if isinstance(x.device, tuple):
-        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+    if cache is not None:
+      # update the cache
+      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
+      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
-    # update the cache
-    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
-
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    else:
+      keys = xk
+      values = xv
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -103,13 +111,13 @@ class TransformerBlock:
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
 # standard openai sampling
-def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
@@ -176,42 +184,56 @@ class Transformer:
     rope_theta=10000,
     max_context=1024,
     jit=True,
-    feed_forward=FeedForward
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False,
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
-    self.forward_jit = TinyJit(self.forward) if jit else None
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
     self.shard = shard
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(x)
-    else:
-      h = x
+    h = x
 
-    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask)
+      h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
     if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float()[:, -1, :]
-      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float().realize()
+      return logits
     else:
       return h
 
-  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
+  def embed(self, inputs: Tensor):
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(inputs)
+    else:
+      h = inputs
+    return h
+
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
-    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+    h = self.embed(tokens)
+    return self.forward(h, start_pos, cache=cache)
 
 
 # *** helpers ***
@@ -245,7 +267,10 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
         v = permute(v, n_heads)
       elif "k_proj" in k:
         v = permute(v, n_kv_heads)
-    sd[keymap[k]] = v
+    if k in keymap:
+      sd[keymap[k]] = v
+    else:
+      sd[k] = v
   return sd
 
 

+ 42 - 0
exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,42 @@
+from tinygrad import Tensor, Variable
+from tinygrad.helpers import getenv
+from collections import OrderedDict
+from typing import List
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class ModelState:
+  cache: List[Tensor]
+  start: int 
+  def __init__(self, cache: List[Tensor], start: int = 0):
+    self.cache = cache
+    self.start = start
+
+class StatefulModel:
+  def __init__(self, model, max_states: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_states = max_states
+    self.states = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.states) >= self.max_states:
+      self.states.popitem(last=False)
+
+    self.states[request_id] = ModelState(cache)
+
+  def __call__(self, x: Tensor, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.states:
+      self.init_cache(h, request_id)
+    else:
+      self.states.move_to_end(request_id)
+    out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
+    self.states[request_id].start += h.shape[1]
+    return out
+
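
Unlike the MLX wrapper, the tinygrad StatefulModel also has to track the decode position itself: ModelState.start records where in the per-request KV cache the next keys/values are written, and each call advances it by the sequence length. Reduced to the bookkeeping alone (an illustrative sketch, not code from the repo):

class PositionTracker:
  """Per-request write offset, mirroring what ModelState.start does."""
  def __init__(self):
    self.start = {}

  def step(self, request_id: str, seqlen: int) -> int:
    start_pos = self.start.get(request_id, 0)
    self.start[request_id] = start_pos + seqlen   # the next chunk is written after this one
    return start_pos

tracker = PositionTracker()
assert tracker.step("req", 4) == 0   # a 4-token prompt fills positions 0-3
assert tracker.step("req", 1) == 4   # the first decoded token goes to position 4
assert tracker.step("req", 1) == 5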

+ 6 - 1
exo/inference/tokenizers.py

@@ -7,14 +7,18 @@ from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
+
 class DummyTokenizer:
   def __init__(self):
     self.eos_token_id = 0
+
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
-    return [1,2,3]
+    return [1, 2, 3]
+
   def decode(self, tokens):
     return "dummy"
 
+
 async def resolve_tokenizer(model_id: str):
   if model_id == "dummy":
     return DummyTokenizer()
@@ -29,6 +33,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
 
+
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")

+ 76 - 40
exo/main.py

@@ -1,8 +1,12 @@
 import argparse
 import asyncio
+import atexit
 import signal
 import json
 import logging
+import platform
+import os
+import sys
 import time
 import traceback
 import uuid
@@ -17,22 +21,24 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -42,7 +48,7 @@ parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale",
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
@@ -54,14 +60,13 @@ parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 
-
 print_yellow_exo()
 
-
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
+                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
@@ -84,9 +89,23 @@ if DEBUG >= 0:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
 if args.discovery_module == "udp":
-  discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+  discovery = UDPDiscovery(
+    args.node_id,
+    args.node_port,
+    args.listen_port,
+    args.broadcast_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout
+  )
 elif args.discovery_module == "tailscale":
-  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
+  discovery = TailscaleDiscovery(
+    args.node_id,
+    args.node_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout,
+    tailscale_api_key=args.tailscale_api_key,
+    tailnet=args.tailnet_name
+  )
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
@@ -108,22 +127,26 @@ api = ChatGPTAPI(
   node,
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
 )
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
+
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
     status = json.loads(opaque_status)
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
   except Exception as e:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
+
+
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
 if args.prometheus_client_port:
@@ -132,38 +155,24 @@ if args.prometheus_client_port:
 
 last_broadcast_time = 0
 
-def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
-            "type": "download_progress",
-            "node_id": node.id,
-            "progress": event.to_dict()
-        })))
 
-shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 
 
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  inference_class = inference_engine.__class__.__name__
+  shard = build_base_shard(model_name, inference_class)
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
-  tokenizer = await resolve_tokenizer(shard.model_id)
+  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   request_id = str(uuid.uuid4())
   callback_id = f"cli-wait-response-{request_id}"
   callback = node.on_token.register(callback_id)
@@ -173,7 +182,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 
   try:
     print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, None, request_id=request_id)
+    await node.process_prompt(shard, prompt, request_id=request_id)
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
 
@@ -189,12 +198,38 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 async def main():
   loop = asyncio.get_running_loop()
 
+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+    
+  if args.models_seed_dir is not None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
+  def restore_cursor():
+    if platform.system() != "Windows":
+        os.system("tput cnorm")  # Show cursor
+
+  # Restore the cursor when the program exits
+  atexit.register(restore_cursor)
+
   # Use a more direct approach to handle signals
   def handle_exit():
-    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
   await node.start(wait_for_peers=args.wait_for_peers)
 
@@ -217,8 +252,9 @@ def run():
   except KeyboardInterrupt:
     print("Received keyboard interrupt. Shutting down...")
   finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
     loop.close()
 
+
 if __name__ == "__main__":
   run()

+ 125 - 50
exo/models.py

@@ -1,73 +1,148 @@
 from exo.inference.shard import Shard
+from typing import Optional, List
 
-model_base_shards = {
+model_cards = {
   ### llama
   "llama-3.2-1b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "layers": 16,
+    "repo": {
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
   },
   "llama-3.2-3b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
   },
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+    },
   },
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   "llama-3.1-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
-  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+    },
   },
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+    },
+  },
+  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
   ### mistral
-  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
-  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
   ### deepseek
-  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
-  "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
   ### llava
-  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
-  "qwen-2.5-coder-1.5b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-coder-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-math-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-14b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),
-  },
-  "qwen-2.5-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "qwen-2.5-math-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
   ### nemotron
-  "nemotron-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "nemotron-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
+  # gemma
+  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # dummy
-  "dummy": {
-    "DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),
-  },
+  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
+}
+
+pretty_name = {
+  "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.1-8b": "Llama 3.1 8B",
+  "llama-3.1-70b": "Llama 3.1 70B",
+  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+  "llama-3.1-405b": "Llama 3.1 405B",
+  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+  "gemma2-9b": "Gemma2 9B",
+  "gemma2-27b": "Gemma2 27B",
+  "nemotron-70b": "Nemotron 70B",
+  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+  "mistral-nemo": "Mistral Nemo",
+  "mistral-large": "Mistral Large",
+  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+  "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+  "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-72b": "Qwen 2.5 72B",
+  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "llama-3-8b": "Llama 3 8B",
+  "llama-3-70b": "Llama 3 70B",
 }
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  repo = get_repo(model_id, inference_engine_classname)
+  n_layers = model_cards.get(model_id, {}).get("layers", 0)
+  if repo is None or n_layers < 1:
+    return None
+  return Shard(model_id, 0, 0, n_layers)
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]
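
To see how the reshaped model registry fits together, here is a short usage sketch of the new helpers; the return values shown are read straight off the model_cards table above:

from exo.models import build_base_shard, get_repo, get_supported_models

# Resolve the HF repo that backs a model for a given inference engine class.
get_repo("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
# -> "mlx-community/Llama-3.2-1B-Instruct-4bit"

# Build the unsplit base shard; the partitioning strategy later maps it across nodes.
build_base_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
# -> Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16)

# Keep only models every node can serve: each inner list is the engine set reported
# by one node, and a model must have a repo for at least one engine in every list.
get_supported_models([["MLXDynamicShardInferenceEngine"], ["TinygradDynamicShardInferenceEngine"]])
# -> ["llama-3.2-1b", "llama-3.2-3b", "llama-3.1-8b", ...]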

+ 11 - 8
exo/networking/grpc/grpc_peer_handle.py

@@ -9,7 +9,7 @@ from . import node_service_pb2_grpc
 from ..peer_handle import PeerHandle
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
-from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 
 
@@ -32,7 +32,11 @@ class GRPCPeerHandle(PeerHandle):
 
   async def connect(self):
     if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
+      self.channel = grpc.aio.insecure_channel(self.address, options=[
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ('grpc.max_receive_message_length', 32*1024*1024),
+        ('grpc.max_send_message_length', 32*1024*1024)
+      ])
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
 
@@ -63,10 +67,9 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
-      image_str=image_str,
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,
@@ -74,7 +77,6 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
       ),
       request_id=request_id,
-      inference_state=inference_state,
     )
     response = await self.stub.SendPrompt(request)
 
@@ -83,7 +85,7 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -93,7 +95,6 @@ class GRPCPeerHandle(PeerHandle):
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
-      inference_state=inference_state,
     )
     response = await self.stub.SendTensor(request)
 
@@ -117,7 +118,9 @@ class GRPCPeerHandle(PeerHandle):
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+      )
       topology.update_node(node_id, device_capabilities)
     for node_id, peers in response.peer_graph.items():
       for peer_id in peers.peer_ids:
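
The added channel options matter because intermediate activations travel as raw tensor bytes, which can easily exceed gRPC's default 4 MiB receive limit. A minimal sketch of opening a peer channel with the same 32 MiB caps (the address is a made-up example):

import grpc

address = "192.168.1.10:5678"  # hypothetical peer address
channel = grpc.aio.insecure_channel(address, options=[
    ("grpc.max_metadata_size", 32 * 1024 * 1024),
    ("grpc.max_receive_message_length", 32 * 1024 * 1024),
    ("grpc.max_send_message_length", 32 * 1024 * 1024),
])
# stub = node_service_pb2_grpc.NodeServiceStub(channel); await channel.channel_ready()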

+ 3 - 5
exo/networking/grpc/grpc_server.py

@@ -49,10 +49,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
     )
     prompt = request.prompt
-    image_str = request.image_str
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
@@ -65,9 +64,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
-    inference_state = request.inference_state
 
-    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    result = await self.node.process_tensor(shard, tensor, request_id)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
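
The tensor payload format is worth spelling out: the array is shipped as raw bytes plus its shape and dtype, and rebuilt with numpy on the other side. A small round-trip sketch, assuming the generated node_service_pb2 bindings are importable:

import numpy as np
from exo.networking.grpc import node_service_pb2

arr = np.random.randn(1, 4, 8).astype(np.float32)  # stand-in activation tensor

# Sender side: serialize as raw bytes, carrying shape and dtype alongside.
msg = node_service_pb2.Tensor(tensor_data=arr.tobytes(), shape=arr.shape, dtype=str(arr.dtype))

# Receiver side: rebuild the array exactly as grpc_server.py / grpc_peer_handle.py do.
restored = np.frombuffer(msg.tensor_data, dtype=np.dtype(msg.dtype)).reshape(msg.shape)
assert np.array_equal(arr, restored)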

+ 2 - 5
exo/networking/grpc/node_service.proto

@@ -22,16 +22,13 @@ message Shard {
 message PromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string image_str = 3;
-  optional string request_id = 4;
-  optional string inference_state = 5;
+  optional string request_id = 3;
 }
 
 message TensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
   optional string request_id = 3;
-  optional string inference_state = 4;
 }
 
 message GetInferenceResultRequest {
@@ -93,4 +90,4 @@ message HealthCheckResponse {
   bool is_healthy = 1;
 }
 
-message Empty {}
+message Empty {}
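
With image_str and inference_state removed, a prompt request now carries only the shard, the prompt, and an optional request_id. A hedged sketch of constructing one from Python (field names taken from the message definitions above; the values are illustrative):

from exo.networking.grpc import node_service_pb2

request = node_service_pb2.PromptRequest(
    prompt="What is the capital of France?",
    shard=node_service_pb2.Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16),
    request_id="example-request",  # optional per the proto definition above
)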

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 9 - 7
exo/networking/manual/manual_discovery.py

@@ -19,7 +19,9 @@ class ManualDiscovery(Discovery):
     self.create_peer_handle = create_peer_handle
 
     if node_id not in self.topology.peers:
-      raise ValueError(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}")
+      raise ValueError(
+        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
+      )
 
     self.listen_task = None
 
@@ -42,7 +44,6 @@ class ManualDiscovery(Discovery):
     if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
     return list(self.known_peers.values())
 
-
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
@@ -52,18 +53,19 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)  
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             self.known_peers[peer_id] = peer
           else:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try: del self.known_peers[peer_id]
-            except KeyError: pass
+            try:
+              del self.known_peers[peer_id]
+            except KeyError:
+              pass
         except Exception as e:
-            if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
       await asyncio.sleep(1.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
-

+ 0 - 1
exo/networking/manual/network_topology_config.py

@@ -17,7 +17,6 @@ class NetworkTopology(BaseModel):
   """
   node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
   """
-
   @classmethod
   def from_path(cls, path: str) -> "NetworkTopology":
     try:

+ 3 - 2
exo/networking/peer_handle.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 
+
 class PeerHandle(ABC):
   @abstractmethod
   def id(self) -> str:
@@ -35,11 +36,11 @@ class PeerHandle(ABC):
     pass
 
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod
-  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
     pass
 
   @abstractmethod

+ 11 - 11
exo/networking/tailscale/tailscale_discovery.py

@@ -8,6 +8,7 @@ from exo.topology.device_capabilities import DeviceCapabilities, device_capabili
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
 
+
 class TailscaleDiscovery(Discovery):
   def __init__(
     self,
@@ -69,14 +70,11 @@ class TailscaleDiscovery(Discovery):
         devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         current_time = time.time()
 
-        active_devices = {
-          name: device for name, device in devices.items()
-          if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
-        }
+        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
 
         if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
         if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
-        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time  - device.last_seen.timestamp()) for device in devices.values()])
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values()])
 
         for device in active_devices.values():
           if device.name == self.node_id: continue
@@ -141,7 +139,13 @@ class TailscaleDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}" for peer_handle, connected_at, last_seen in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            }
+          )
 
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
@@ -164,9 +168,5 @@ class TailscaleDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove

+ 15 - 26
exo/networking/tailscale/tailscale_helpers.py

@@ -7,6 +7,7 @@ from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from datetime import datetime, timezone
 
+
 class Device:
   def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
     self.device_id = device_id
@@ -16,12 +17,7 @@ class Device:
 
   @classmethod
   def from_dict(cls, data: Dict[str, Any]) -> 'Device':
-    return cls(
-      device_id=data.get('id', ''),
-      name=data.get('name', ''),
-      addresses=data.get('addresses', []),
-      last_seen=cls.parse_datetime(data.get('lastSeen'))
-    )
+    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
 
   @staticmethod
   def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
@@ -29,13 +25,10 @@ class Device:
       return None
     return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
 
+
 async def get_device_id() -> str:
   try:
-    process = await asyncio.create_subprocess_exec(
-      'tailscale', 'status', '--json',
-      stdout=asyncio.subprocess.PIPE,
-      stderr=asyncio.subprocess.PIPE
-    )
+    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
     stdout, stderr = await process.communicate()
     if process.returncode != 0:
       raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
@@ -45,22 +38,16 @@ async def get_device_id() -> str:
   except Exception as e:
     raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
 
+
 async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
   async with aiohttp.ClientSession() as session:
     base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}',
-      'Content-Type': 'application/json'
-    }
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
 
     attributes = {
-      "custom:exo_node_id": node_id.replace('-', '_'),
-      "custom:exo_node_port": node_port,
-      "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
-      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model),
-      "custom:exo_device_capability_memory": str(device_capabilities.memory),
-      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
-      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
       "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
     }
 
@@ -73,12 +60,11 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
         else:
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
 
+
 async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}'
-    }
+    headers = {'Authorization': f'Bearer {api_key}'}
     async with session.get(url, headers=headers) as response:
       if response.status == 200:
         data = await response.json()
@@ -100,6 +86,7 @@ async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int,
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
         return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
 
+
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
   result = {}
   prefix = "custom:exo_"
@@ -112,12 +99,14 @@ def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
         result[attr_name] = float(value)
   return result
 
+
 def sanitize_attribute(value: str) -> str:
   # Replace invalid characters with underscores
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   # Truncate to 50 characters
   return sanitized_value[:50]
 
+
 async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
@@ -133,4 +122,4 @@ async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]
         device = Device.from_dict(device_data)
         devices[device.name] = device
 
-      return devices
+      return devices
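
Two of these helpers are easy to sanity-check in isolation; a couple of hedged examples (the inputs are made up, the outputs follow from the code shown above):

from exo.networking.tailscale.tailscale_helpers import Device, sanitize_attribute

# Posture attribute values may only contain [a-zA-Z0-9_.] and are capped at 50 chars.
sanitize_attribute("Apple M3 Max (40-core GPU)")
# -> "Apple_M3_Max__40_core_GPU_"

# Tailscale's lastSeen timestamps are parsed as UTC.
Device.parse_datetime("2024-11-01T12:30:00Z")
# -> datetime.datetime(2024, 11, 1, 12, 30, tzinfo=datetime.timezone.utc)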

+ 2 - 0
exo/networking/tailscale/test_tailscale_discovery.py

@@ -5,6 +5,7 @@ from unittest import mock
 from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.peer_handle import PeerHandle
 
+
 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
@@ -37,5 +38,6 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     # Check if discovered peers are instances of GRPCPeerHandle
     print(peers)
 
+
 if __name__ == '__main__':
   unittest.main()

+ 14 - 15
exo/networking/udp/udp_discovery.py

@@ -9,6 +9,7 @@ from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
 
+
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
@@ -90,17 +91,13 @@ class UDPDiscovery(Discovery):
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1, # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
         })
         if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
 
         transport = None
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-            lambda: BroadcastProtocol(message, self.broadcast_port),
-            local_addr=(addr, 0),
-            family=socket.AF_INET
-          )
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
           if DEBUG_DISCOVERY >= 3:
             print(f"Broadcasting presence at ({addr})")
         except Exception as e:
@@ -145,7 +142,8 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers:
           existing_peer_prio = self.known_peers[peer_id][3]
           if existing_peer_prio >= peer_prio:
-            if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
+            if DEBUG >= 1:
+              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         if not await new_peer_handle.health_check():
@@ -161,8 +159,7 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
 
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
-                                                            local_addr=("0.0.0.0", self.listen_port))
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2: print("Started listen task")
 
   async def task_cleanup_peers(self):
@@ -177,7 +174,13 @@ class UDPDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
+              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
+            }
+          )
 
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
@@ -200,9 +203,5 @@ class UDPDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove

+ 2 - 2
exo/orchestration/node.py

@@ -16,11 +16,11 @@ class Node(ABC):
     pass
 
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass
 
   @abstractmethod

+ 112 - 122
exo/orchestration/standard_node.py

@@ -39,6 +39,8 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -87,24 +89,53 @@ class StandardNode(Node):
   def get_supported_inference_engines(self):
     supported_engine_names = []
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-        supported_engine_names.append('mlx')
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
     else:
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('tinygrad')
     return supported_engine_names
 
   async def broadcast_supported_engines(self, supported_engines_names: List[str]):
-    status_message = json.dumps({
-        "type": "supported_inference_engines",
-        "node_id": self.id,
-        "engines": supported_engines_names
-    })
+    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
     await self.broadcast_opaque_status("", status_message)
 
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
+  
+  async def process_inference_result(
+    self,
+    shard,
+    result: np.ndarray,
+    request_id: Optional[str] = None,
+  ):
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+    else:
+      forward = result
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+    else:
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset=1)))
+
+    return np.array(self.buffered_token_output[request_id][0])
+
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -116,14 +147,12 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
-          "inference_state": inference_state,
           "request_id": request_id,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -136,8 +165,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
-          "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
           "result_size": resp.size if resp is not None else 0,
@@ -146,42 +173,26 @@ class StandardNode(Node):
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
-    if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
-      return
-
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
-    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:
-      self.buffered_token_output[request_id][0].append(result.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-    if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if not shard.is_first_layer():
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      return None
+    else:
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      ret = await self.process_inference_result(shard, result, request_id) 
+      return result
 
   async def process_tensor(
     self,
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -196,12 +207,11 @@ class StandardNode(Node):
           "tensor_size": tensor.size,
           "tensor_shape": tensor.shape,
           "request_id": request_id,
-          "inference_state": inference_state,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    resp = await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -226,84 +236,77 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
-      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if is_finished:
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-      if result.size == 1:  # we got a new token out
-        self.buffered_token_output[request_id][0].append(result.item())
-        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-      if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+      ret = await self.process_inference_result(shard, result, request_id) 
+      return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
-    image_str: Optional[str] = None,
-    inference_state: Optional[str] = None,
+    target_index: int,
   ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+
+  def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-    shard = self.get_current_shard(base_shard)
-
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-    if current_partition_index is not None:
-      next_partition_index = (current_partition_index+1) % len(partitions)
-      next_partition: Partition = partitions[next_partition_index]
-      next_shard = shards[next_partition_index]
-      if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-      if next_partition.node_id == self.id:
-        if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-        else:
-          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
-        return
-
-      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {next_partition} not found")
-
-      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-      if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-      else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+    return (current_partition_index + offset) % len(partitions)
 
-  def get_current_shard(self, base_shard: Shard) -> Shard:
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
     partitions = self.partitioning_strategy.partition(self.topology)
     shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if current_partition_index is None:
-      raise ValueError(f"No current partition found for node: {self.id}")
-    return shards[current_partition_index]
+    return shards[index]
 
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
     next_peers = await self.discovery.discover_peers(wait_for_peers)
@@ -311,20 +314,16 @@ class StandardNode(Node):
     next_peer_ids = {peer.id() for peer in next_peers}
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
-    peers_updated = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
-    peers_unchanged = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
+    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
 
     def _pretty(peers: List[PeerHandle]) -> List[str]:
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
-    if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    if DEBUG >= 2:
+      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
 
     async def disconnect_with_timeout(peer, timeout=5):
       try:
@@ -344,14 +343,8 @@ class StandardNode(Node):
         traceback.print_exc()
         return False
 
-    disconnect_results = await asyncio.gather(
-      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
-      return_exceptions=True
-    )
-    connect_results = await asyncio.gather(
-      *(connect_with_timeout(peer) for peer in peers_to_connect),
-      return_exceptions=True
-    )
+    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
 
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
@@ -370,12 +363,7 @@ class StandardNode(Node):
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
     if len(self.get_topology_inference_engines()):
-      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
-        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
-        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
-      else:
-        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader) 
+      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
 
   async def periodic_topology_collection(self, interval: int):
     while True:
@@ -422,6 +410,7 @@ class StandardNode(Node):
         self.topology.merge(other_topology)
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
+        traceback.print_exc()
 
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
@@ -440,7 +429,7 @@ class StandardNode(Node):
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
-
+  
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):
       try:
@@ -464,6 +453,7 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
+
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
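
The net effect of the refactor above is a ring: each partition runs its layers, then the output hops to partition (current index + offset) mod N via forward_tensor, and sampling happens only on the shard that owns the last layer. A self-contained sketch of just the index arithmetic, with plain dicts standing in for exo's Partition objects:

# Minimal sketch of the get_partition_index(offset=1) hop used by forward_tensor.
def next_partition_index(partitions, node_id, offset=1):
    current = next((i for i, p in enumerate(partitions) if p["node_id"] == node_id), None)
    if current is None:
        raise ValueError(f"No current partition found for node: {node_id}")
    return (current + offset) % len(partitions)

partitions = [{"node_id": "node1"}, {"node_id": "node2"}, {"node_id": "node3"}]
assert next_partition_index(partitions, "node1") == 1  # mid-ring: pass activations downstream
assert next_partition_index(partitions, "node3") == 0  # last node wraps back to the first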

+ 129 - 62
exo/tinychat/index.css

@@ -1,31 +1,11 @@
 /* define colors */
 :root {
-  --primary-color: #a52e4d;
-  --primary-color-transparent: #a52e4d66;
-  --secondary-color: #228039;
-  --secondary-color-transparent: #22803966;
-
+  --primary-color: #fff;
+  --secondary-color: #2a2a2a;
+  --secondary-color-transparent: #ffffff66;
+  --primary-bg-color: #1a1a1a;
+  --foreground-color: #f0f0f0;
   --red-color: #a52e4d;
-  --green-color: #228039;
-  --silver-color: #88808e;
-}
-@media(prefers-color-scheme: light) {
-  :root {
-    --primary-bg-color: #f0f0f0;
-    --secondary-bg-color: #eeeeee;
-    --tertiary-bg-color: #dddddd;
-    --foreground-color: #111111;
-    --accent-color: #000000;
-  }
-}
-@media(prefers-color-scheme: dark) {
-  :root {
-    --primary-bg-color: #111111;
-    --secondary-bg-color: #131313;
-    --tertiary-bg-color: #232323;
-    --foreground-color: #f0f0f0;
-    --accent-color: #aaaaaa;
-  }
 }
 
 main {
@@ -81,7 +61,11 @@ main {
   top: 0;
   position: absolute;
 
-  background: linear-gradient(180deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    180deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 .histories-end {
   height: 3rem;
@@ -91,7 +75,11 @@ main {
   bottom: 0;
   position: absolute;
 
-  background: linear-gradient(0deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 
 .history {
@@ -99,7 +87,7 @@ main {
   width: 100%;
   max-width: 40rem;
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   border-radius: 10px;
   border-left: 2px solid var(--primary-color);
 
@@ -109,7 +97,7 @@ main {
   opacity: var(--opacity, 1);
 }
 .history:hover {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
 }
 
 .history-delete-button {
@@ -120,14 +108,14 @@ main {
   margin: 0;
   outline: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   cursor: pointer;
   transition: 0.2s;
 }
 .history-delete-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
 }
 
@@ -135,6 +123,7 @@ main {
   overflow-y: auto;
   height: 100%;
   width: 100%;
+  max-width: 1200px;
 
   display: flex;
   flex-direction: column;
@@ -145,28 +134,25 @@ main {
 }
 
 .message {
-  width: 96%;
-  max-width: 80rem;
-
-  display: grid;
-
-  background-color: var(--secondary-bg-color);
+  max-width: 75%;
   padding: 0.5rem 1rem;
-  border-radius: 10px;
+  border-radius: 20px;
 }
 .message-role-assistant {
-  border-bottom: 2px solid var(--primary-color);
-  border-left: 2px solid var(--primary-color);
-  box-shadow: -10px 10px 20px 2px var(--primary-color-transparent);
+  background-color: var(--secondary-color);
+  margin-right: auto;
+  color: #fff;
 }
 .message-role-user {
-  border-bottom: 2px solid var(--secondary-color);
-  border-right: 2px solid var(--secondary-color);
-  box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
+  margin-left: auto;
+  background-color: var(--primary-color);
+  color: #000;
 }
 .download-progress {
   margin-bottom: 12em;
   overflow-y: auto;
+  min-height: 350px;
+  padding: 2rem;
 }
 .message > pre {
   white-space: pre-wrap;
@@ -191,17 +177,70 @@ main {
 }
 
 .toast {
-    width: 100%; /* Take up the full width of the page */
-    background-color: #fc2a2a; /* Dark background color */
-    color: #fff; /* White text color */
-    text-align: center; /* Centered text */
-    border-radius: 2px; /* Rounded borders */
-    padding: 16px; /* Padding */
-    position: fixed; /* Sit on top of the screen content */
-    z-index: 9999; /* Add a z-index if needed */
-    top: 0; /* Position at the top of the page */
-    left: 0; /* Extend from the left edge */
-    right: 0; /* Extend to the right edge */
+    width: 100%;
+    background-color: #fc2a2a;
+    color: #fff;
+    text-align: left;
+    border-radius: 2px;
+    padding: 16px;
+    position: fixed;
+    z-index: 9999;
+    top: 0;
+    left: 0;
+    right: 0;
+    display: flex;
+    flex-direction: column;
+    white-space: pre-wrap;
+    font-family: monospace;
+}
+
+.toast-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    width: 100%;
+}
+
+.toast-error-message {
+    flex-grow: 1;
+}
+
+.toast-header-buttons {
+    display: flex;
+    align-items: center;
+    gap: 16px;
+    margin-left: 24px;
+}
+
+.toast-expand-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1em;
+}
+
+.toast-close-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1.2em;
+    line-height: 1;
+}
+
+.toast-expand-button:hover,
+.toast-close-button:hover {
+    opacity: 0.8;
+}
+
+.toast-content {
+    margin-top: 10px;
+    padding: 10px;
+    background-color: rgba(0, 0, 0, 0.2);
+    border-radius: 4px;
 }
 
 .hljs {
@@ -220,14 +259,14 @@ main {
   margin: 0;
   outline: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   cursor: pointer;
   transition: 0.2s;
 }
 .clipboard-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
 }
 
@@ -236,9 +275,14 @@ main {
   bottom: 0;
 
   /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(0deg, var(--primary-bg-color) 55%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
 
   width: 100%;
+  max-width: 1200px;
   display: flex;
   flex-direction: column;
   justify-content: center;
@@ -285,7 +329,7 @@ main {
   min-height: 3rem;
   max-height: 8rem;
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   border-radius: 10px;
   border: none;
@@ -297,8 +341,8 @@ main {
   height: 3rem;
   width: 4rem;
 
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
+  background-color: var(--primary-color);
+  color: var(--secondary-color);
   border-radius: 10px;
   padding: 0.5rem;
   cursor: pointer;
@@ -307,7 +351,7 @@ main {
   background-color: var(--secondary-color-transparent);
 }
 .input-button:disabled {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   cursor: not-allowed;
 }
 
@@ -414,4 +458,27 @@ p {
   max-width: 100%;
   max-height: 100%;
   object-fit: contain;
+}
+
+.clear-history-button {
+  background-color: var(--red-color);
+  color: white;
+  padding: 10px 20px;
+  border-radius: 5px;
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  transition: all 0.3s ease;
+  margin: 1rem auto;
+  border: none;
+  cursor: pointer;
+}
+
+.clear-history-button:hover {
+  opacity: 0.8;
+  transform: scale(1.05);
+}
+
+.clear-history-button i {
+  font-size: 14px;
 }

+ 26 - 25
exo/tinychat/index.html

@@ -26,33 +26,27 @@
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
      <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity x-text="errorMessage" class="toast">
+    <div x-show="errorMessage" x-transition.opacity class="toast">
+        <div class="toast-header">
+            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <div class="toast-header-buttons">
+                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
+                        class="toast-expand-button" 
+                        x-show="errorMessage.stack">
+                    <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
+                </button>
+                <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
+                    <i class="fas fa-times"></i>
+                </button>
+            </div>
+        </div>
+        <div class="toast-content" x-show="errorExpanded" x-transition>
+            <span x-text="errorMessage.stack"></span>
+        </div>
     </div>
 <div class="model-selector">
-<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel">
-<option selected="" value="llama-3.2-1b">Llama 3.2 1B</option>
-<option value="llama-3.2-3b">Llama 3.2 3B</option>
-<option value="llama-3.1-8b">Llama 3.1 8B</option>
-<option value="llama-3.1-70b">Llama 3.1 70B</option>
-<option value="llama-3.1-70b-bf16">Llama 3.1 70B (BF16)</option>
-<option value="llama-3.1-405b">Llama 3.1 405B</option>
-<option value="llama-3-8b">Llama 3 8B</option>
-<option value="llama-3-70b">Llama 3 70B</option>
-<option value="nemotron-70b">Nemotron 70B</option>
-<option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
-<option value="mistral-nemo">Mistral Nemo</option>
-<option value="mistral-large">Mistral Large</option>
-<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
-<option value="deepseek-coder-v2.5">Deepseek Coder V2.5</option>
-<option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
-<option value="qwen-2.5-coder-1.5b">Qwen 2.5 Coder 1.5B</option>
-<option value="qwen-2.5-coder-7b">Qwen 2.5 Coder 7B</option>
-<option value="qwen-2.5-7b">Qwen 2.5 7B</option>
-<option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
-<option value="qwen-2.5-14b">Qwen 2.5 14B</option>
-<option value="qwen-2.5-72b">Qwen 2.5 72B</option>
-<option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
-</select>
+  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+  </select>
 </div>
 <div @popstate.window="
       if (home === 2) {
@@ -68,6 +62,13 @@
       if (home === -1) setTimeout(() =&gt; home = 0, 100);
     " x-show="home === 0" x-transition="">
 <h1 class="title megrim-regular">tinychat</h1>
+<template x-if="histories.length">
+  <button 
+    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
+    class="clear-history-button">
+    <i class="fas fa-trash"></i> Clear All History
+  </button>
+</template>
 <div class="histories-container-container">
 <template x-if="histories.length">
 <div class="histories-start"></div>

+ 109 - 14
exo/tinychat/index.js

@@ -4,8 +4,8 @@ document.addEventListener("alpine:init", () => {
     cstate: {
       time: null,
       messages: [],
-      selectedModel: 'llama-3.1-8b',
-    },
+      selectedModel: 'llama-3.2-1b',
+    },    
 
     // historical state
     histories: JSON.parse(localStorage.getItem("histories")) || [],
@@ -14,6 +14,8 @@ document.addEventListener("alpine:init", () => {
     generating: false,
     endpoint: `${window.location.origin}/v1`,
     errorMessage: null,
+    errorExpanded: false,
+    errorTimeout: null,
 
     // performance tracking
     time_till_first: 0,
@@ -47,6 +49,12 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("histories", JSON.stringify(this.histories));
       }
     },
+
+    clearAllHistory() {
+      this.histories = [];
+      localStorage.setItem("histories", JSON.stringify([]));
+    },
+
     // Utility functions
     formatBytes(bytes) {
       if (bytes === 0) return '0 B';
@@ -66,6 +74,56 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
     },
 
+    async populateSelector() {
+      try {
+        const response = await fetch(`${window.location.origin}/modelpool`);
+        const responseText = await response.text(); // Get raw response text first
+        
+        if (!response.ok) {
+          throw new Error(`HTTP error! status: ${response.status}`);
+        }
+        
+        // Try to parse the response text
+        let responseJson;
+        try {
+          responseJson = JSON.parse(responseText);
+        } catch (parseError) {
+          console.error('Failed to parse JSON:', parseError);
+          throw new Error(`Invalid JSON response: ${responseText}`);
+        }
+
+        const sel = document.querySelector(".model-select");
+        if (!sel) {
+          throw new Error("Could not find model selector element");
+        }
+
+        // Clear the current options and add new ones
+        sel.innerHTML = '';
+          
+        const modelDict = responseJson["model pool"];
+        if (!modelDict) {
+          throw new Error("Response missing 'model pool' property");
+        }
+
+        Object.entries(modelDict).forEach(([key, value]) => {
+          const opt = document.createElement("option");
+          opt.value = key;
+          opt.textContent = value;
+          sel.appendChild(opt);
+        });
+
+        // Set initial value to the first model
+        const firstKey = Object.keys(modelDict)[0];
+        if (firstKey) {
+          sel.value = firstKey;
+          this.cstate.selectedModel = firstKey;
+        }
+      } catch (error) {
+        console.error("Error populating model selector:", error);
+        this.errorMessage = `Failed to load models: ${error.message}`;
+      }
+    },
+
     async handleImageUpload(event) {
       const file = event.target.files[0];
       if (file) {
@@ -110,12 +168,30 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("pendingMessage", value);
         this.processMessage(value);
       } catch (error) {
-        console.error('error', error)
-        this.lastErrorMessage = error.message || 'Unknown error on handleSend';
-        this.errorMessage = error.message || 'Unknown error on handleSend';
-        setTimeout(() => {
-          this.errorMessage = null;
-        }, 5 * 1000)
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
         this.generating = false;
       }
     },
@@ -232,12 +308,30 @@ document.addEventListener("alpine:init", () => {
           console.error("Failed to save histories to localStorage:", error);
         }
       } catch (error) {
-        console.error('error', error)
-        this.lastErrorMessage = error;
-        this.errorMessage = error;
-        setTimeout(() => {
-          this.errorMessage = null;
-        }, 5 * 1000)
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
       } finally {
         this.generating = false;
       }
@@ -529,6 +623,7 @@ function createParser(onParse) {
     }
   }
 }
+
 const BOM = [239, 187, 191];
 function hasBom(buffer) {
   return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
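
The new populateSelector() builds the model dropdown from a GET /modelpool request and expects the JSON body to keep the options under a "model pool" key, while the reworked error handling wraps failures into {basic, stack} objects shown by the toast. A hedged aiohttp sketch of an endpoint with the shape the frontend expects; the handler, route registration and placeholder model names are illustrative, not exo's actual API code:

from aiohttp import web

def model_pool() -> dict:
  # Placeholder id -> display-name pairs; a real server would derive these from its model registry.
  return {"llama-3.2-1b": "Llama 3.2 1B", "llama-3.2-3b": "Llama 3.2 3B"}

async def handle_model_pool(request: web.Request) -> web.Response:
  return web.json_response({"model pool": model_pool()})

app = web.Application()
app.add_routes([web.get("/modelpool", handle_model_pool)])
# web.run_app(app, port=52415)  # 52415 is the API port used elsewhere in this changeset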

+ 44 - 41
exo/tinychat/update_deps.py

@@ -4,49 +4,52 @@ from bs4 import BeautifulSoup
 from urllib.parse import urljoin, urlparse
 import re
 
+
 def download_file(url, local_path):
-    response = requests.get(url)
-    if response.status_code == 200:
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, 'wb') as f:
-            f.write(response.content)
-        print(f"Downloaded: {local_path}")
-    else:
-        print(response.status_code)
-        print(f"Failed to download: {url}")
+  response = requests.get(url)
+  if response.status_code == 200:
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+      f.write(response.content)
+    print(f"Downloaded: {local_path}")
+  else:
+    print(response.status_code)
+    print(f"Failed to download: {url}")
+
 
 def update_html(html_content, base_url):
-    soup = BeautifulSoup(html_content, 'html.parser')
+  soup = BeautifulSoup(html_content, 'html.parser')
 
-    for tag in soup.find_all(['script', 'link']):
-        if tag.has_attr('src'):
-            url = tag['src']
-        elif tag.has_attr('href'):
-            url = tag['href']
-        else:
-            continue
+  for tag in soup.find_all(['script', 'link']):
+    if tag.has_attr('src'):
+      url = tag['src']
+    elif tag.has_attr('href'):
+      url = tag['href']
+    else:
+      continue
+
+    if url.startswith(('http://', 'https://')):
+      full_url = url
+    else:
+      full_url = urljoin(base_url, url)
 
-        if url.startswith(('http://', 'https://')):
-            full_url = url
-        else:
-            full_url = urljoin(base_url, url)
+    parsed_url = urlparse(full_url)
+    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
 
-        parsed_url = urlparse(full_url)
-        local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+    download_file(full_url, local_path)
 
-        download_file(full_url, local_path)
+    relative_path = os.path.relpath(local_path, '.')
+    if tag.name == 'script':
+      tag['src'] = "/" + relative_path
+    elif tag.name == 'link':
+      tag['href'] = "/" + relative_path
 
-        relative_path = os.path.relpath(local_path, '.')
-        if tag.name == 'script':
-            tag['src'] = "/" + relative_path
-        elif tag.name == 'link':
-            tag['href'] = "/" + relative_path
+  return str(soup)
 
-    return str(soup)
 
 # Read the HTML file
 with open('./index.html', 'r') as f:
-    html_content = f.read()
+  html_content = f.read()
 
 # Update HTML and download files
 # updated_html = update_html(html_content, 'https://example.com')
@@ -68,7 +71,7 @@ download_file(css_url, css_output_path)
 
 # Parse CSS file for font URLs
 with open(css_output_path, 'r', encoding='utf-8') as f:
-    css_content = f.read()
+  css_content = f.read()
 
 # Extract font URLs from the CSS content
 font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
@@ -77,14 +80,14 @@ print(f"Found {len(font_urls)} font URLs")
 
 # Download font files
 for font_url in font_urls:
-    font_url = font_url.strip('"\'')
-    if font_url.startswith('../'):
-        font_url = font_url[3:]
+  font_url = font_url.strip('"\'')
+  if font_url.startswith('../'):
+    font_url = font_url[3:]
 
-    # Use base_url instead of urljoin to keep the version number
-    full_url = base_url + font_url
-    relative_path = font_url
-    output_path = os.path.join(output_dir, relative_path)
-    download_file(full_url, output_path)
+  # Use base_url instead of urljoin to keep the version number
+  full_url = base_url + font_url
+  relative_path = font_url
+  output_path = os.path.join(output_dir, relative_path)
+  download_file(full_url, output_path)
 
-print("Download complete!")
+print("Download complete!")

+ 4 - 2
exo/topology/device_capabilities.py

@@ -52,9 +52,11 @@ CHIP_FLOPS = {
   "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
   "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
   "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
   "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-  "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=4.26*TFLOPS, fp16=8.52*TFLOPS, int8=17.04*TFLOPS),
+  "Apple M4 Pro": DeviceFlops(fp32=5.72*TFLOPS, fp16=11.44*TFLOPS, int8=22.88*TFLOPS),
+  "Apple M4 Max": DeviceFlops(fp32=18.03*TFLOPS, fp16=36.07*TFLOPS, int8=72.14*TFLOPS),
   ### A chips
   "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
   "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),

+ 1 - 1
extra/start_openwebui.sh

@@ -1,3 +1,3 @@
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main

+ 1 - 1
format.py

@@ -21,7 +21,7 @@ def run_yapf(target):
 
 def main():
   if len(sys.argv) < 2:
-    print("Usage: python format.py <directory_or_file>")
+    print("Usage: python3 format.py <directory_or_file> e.g. python3 format.py ./exo")
     sys.exit(1)
 
   target = sys.argv[1]

+ 0 - 5
lint.sh

@@ -1,5 +0,0 @@
-#!/bin/bash
-
-pip3 install -e '.[linting]'
-python3 -m ruff check .
-python3 -m pylint .

+ 0 - 7
pyproject.toml

@@ -1,7 +0,0 @@
-[tool.pylint.format]
-indent-string = '  '
-max-line-length = 200
-
-[tool.autopep8]
-max_line_length = 200
-indent_size = 2

+ 0 - 43
ruff.toml

@@ -1,43 +0,0 @@
-indent-width = 2
-preview = true
-target-version = "py312"
-
-lint.select = [
-  "F",  # Pyflakes
-  "W6",
-  "E71",
-  "E72",
-  "E112",   # no-indented-block
-  "E113",   # unexpected-indentation
-  # "E124",
-  "E203",   # whitespace-before-punctuation
-  "E272",   # multiple-spaces-before-keyword
-  "E303",   # too-many-blank-lines
-  "E304",   # blank-line-after-decorator
-  "E501",   # line-too-long
-  # "E502",
-  "E702",   # multiple-statements-on-one-line-semicolon
-  "E703",   # useless-semicolon
-  "E731",   # lambda-assignment
-  "W191",   # tab-indentation
-  "W291",   # trailing-whitespace
-  "W293",   # blank-line-with-whitespace
-  "UP039",  # unnecessary-class-parentheses
-  "C416",   # unnecessary-comprehension
-  "RET506", # superfluous-else-raise
-  "RET507", # superfluous-else-continue
-  "A",      # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
-  "SIM105", # suppressible-exception
-  "FURB110",# if-exp-instead-of-or-operator
-]
-
-line-length = 200
-
-exclude = [
-  "docs/",
-  "examples/",
-  "extra/",
-  "exo/networking/grpc/node_service_pb2.py",
-  "exo/networking/grpc/node_service_pb2_grpc.py",
-  "exo/helpers.py",
-]

+ 60 - 0
scripts/build_exo.py

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-sign-identity=auto",
+            "--macos-sign-notarization",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()
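
On macOS the script walks exo/inference/mlx/models with pkgutil.iter_modules and turns every module it finds into a --include-module flag, presumably because the model implementations are selected by name at runtime and would otherwise be missed by --follow-imports; running python3 scripts/build_exo.py from the repo root then produces a onefile binary under dist/. A standalone sketch of just that discovery step (the helper name is illustrative):

import pkgutil

def mlx_model_include_flags(models_dir: str = "exo/inference/mlx/models") -> list:
  # pkgutil.iter_modules yields (finder, name, ispkg) for each module file in the directory.
  modules = [name for _, name, _ in pkgutil.iter_modules([models_dir])]
  return [f"--include-module=exo.inference.mlx.models.{m}" for m in modules]

if __name__ == "__main__":
  # Prints one --include-module flag per model file found; empty if run outside the repo root.
  print(mlx_model_include_flags())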

+ 7 - 0
scripts/compile_grpc.sh

@@ -0,0 +1,7 @@
+#!/bin/bash
+source ./install.sh
+pushd exo/networking/grpc
+python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
+sed -i "s/import\ node_service_pb2/from . &/" node_service_pb2_grpc.py
+popd
+
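
The sed expression rewrites the absolute import that grpc_tools.protoc emits into a relative one, so the generated stub resolves node_service_pb2 from within the exo.networking.grpc package; the & in the replacement stands for the whole matched text. The same substitution expressed in Python, applied to a line shaped like the generated one:

import re

line = "import node_service_pb2 as node__service__pb2"
print(re.sub(r"import node_service_pb2", r"from . \g<0>", line))
# -> from . import node_service_pb2 as node__service__pb2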

+ 10 - 13
setup.py

@@ -5,40 +5,37 @@ from setuptools import find_packages, setup
 
 # Base requirements for all platforms
 install_requires = [
-  "aiohttp==3.10.2",
+  "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
-  "protobuf==5.27.1",
+  "protobuf==5.28.1",
   "psutil==6.0.0",
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
 ]
 
 extras_require = {
-  "linting": [
-    "pylint==3.2.6",
-    "ruff==0.5.5",
-    "mypy==1.11.0",
+  "formatting": [
     "yapf==0.40.2",
   ],
   "apple_silicon": [
-    "mlx==0.18.0",
-    "mlx-lm==0.18.2",
+    "mlx==0.20.0",
+    "mlx-lm==0.19.3",
   ],
 }
 

+ 1 - 1
test/reconnect.sh

@@ -1,7 +1,7 @@
 #!/bin/bash
 
 echo "Starting node 1"
-DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
 PID1=$!
 echo "Started node 1 PID: $PID1"
 echo "Starting node 2"

+ 121 - 0
test/test_model_helpers.py

@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()
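
Read together, the cases pin down the semantics expected of get_supported_models: engines inside one list are alternatives, a model has to be matched by every list that is supplied, and an empty selection falls back to the whole of model_cards. A hedged sketch with those semantics; it illustrates the expected behaviour rather than the implementation in exo/models.py, and it assumes each card keeps its per-engine repos under a "repo" key as the test_tokenizers.py change below does:

from typing import Dict, List

def get_supported_models_sketch(model_cards: Dict[str, dict], engine_lists: List[List[str]]) -> List[str]:
  # No engine constraint: every model card qualifies (the "no_engines" case above).
  if not engine_lists:
    return list(model_cards.keys())

  def matches(card: dict, engines: List[str]) -> bool:
    # Engines within a single list are ORed together.
    return any(engine in card.get("repo", {}) for engine in engines)

  # A model must satisfy every engine list, i.e. the lists are ANDed.
  return [model_id for model_id, card in model_cards.items()
          if all(matches(card, engines) for engines in engine_lists)]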

+ 8 - 3
test/test_tokenizers.py

@@ -1,7 +1,7 @@
 import os
 import re
 from transformers import AutoTokenizer, AutoProcessor
-from exo.models import model_base_shards
+from exo.models import model_cards
 
 
 def test_tokenizer(name, tokenizer, verbose=False):
@@ -24,9 +24,14 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
-models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if not ignore_pattern.match(shard.model_id)]
+models = []
+for model_id in model_cards:
+  for engine_type, repo_id in model_cards[model_id].get("repo", {}).items():
+    if not ignore_pattern.match(repo_id):
+      models.append(repo_id)
+models = list(set(models))
 
 verbose = os.environ.get("VERBOSE", "0").lower() == "1"
 for m in models:
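
The rewritten collection loop implies that model_cards maps a model id to a card whose "repo" dict is keyed by inference-engine class name and valued with a Hugging Face repo id, and that the same repo can back several engines, hence the set() dedup. A hedged illustration of that assumed shape; the engine class name below is made up and the repo ids are simply ones already named in the ignore list:

model_cards = {
  "llama-3.1-405b": {"repo": {"ExampleInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"}},
  "llava-1.5-7b-hf": {"repo": {"ExampleInferenceEngine": "llava-hf/llava-1.5-7b-hf"}},
}

models = []
for model_id in model_cards:
  for engine_type, repo_id in model_cards[model_id].get("repo", {}).items():
    models.append(repo_id)  # the test additionally drops anything matching ignore_pattern
models = list(set(models))
print(models)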
