
Merge branch 'main' into main

Alex Cheema 6 months ago
parent
commit 675b1e16a8
64 changed files with 1584 additions and 1402 deletions
  1. +35 -35    .circleci/config.yml
  2. +1 -1      .gitignore
  3. +0 -472    .pylintrc
  4. +25 -5     README.md
  5. +18 -2     configure_mlx.sh
  6. BIN        docs/exo-rounded.png
  7. +1 -1      examples/astra/astra/ContentView.swift
  8. +1 -1      examples/chatgpt_api.sh
  9. +1 -1      exo/__init__.py
  10. +53 -41   exo/api/chatgpt_api.py
  11. +34 -6    exo/download/hf/hf_helpers.py
  12. +8 -6     exo/download/hf/hf_shard_download.py
  13. +4 -2     exo/download/shard_download.py
  14. +21 -1    exo/helpers.py
  15. +11 -12   exo/inference/debug_inference_engine.py
  16. +17 -38   exo/inference/dummy_inference_engine.py
  17. +22 -3    exo/inference/inference_engine.py
  18. +1 -1     exo/inference/mlx/models/base.py
  19. +1 -1     exo/inference/mlx/models/deepseek_v2.py
  20. +118 -0   exo/inference/mlx/models/gemma2.py
  21. +2 -1     exo/inference/mlx/models/qwen2.py
  22. +46 -21   exo/inference/mlx/sharded_inference_engine.py
  23. +0 -86    exo/inference/mlx/sharded_model.py
  24. +14 -8    exo/inference/mlx/sharded_utils.py
  25. +42 -0    exo/inference/mlx/stateful_model.py
  26. +4 -4     exo/inference/mlx/test_sharded_llama.py
  27. +2 -2     exo/inference/mlx/test_sharded_llava.py
  28. +39 -42   exo/inference/test_dummy_inference_engine.py
  29. +14 -24   exo/inference/test_inference_engine.py
  30. +34 -36   exo/inference/tinygrad/inference.py
  31. +61 -36   exo/inference/tinygrad/models/llama.py
  32. +42 -0    exo/inference/tinygrad/stateful_model.py
  33. +6 -1     exo/inference/tokenizers.py
  34. +76 -40   exo/main.py
  35. +125 -50  exo/models.py
  36. +11 -8    exo/networking/grpc/grpc_peer_handle.py
  37. +3 -5     exo/networking/grpc/grpc_server.py
  38. +2 -5     exo/networking/grpc/node_service.proto
  39. +0 -0     exo/networking/grpc/node_service_pb2.py
  40. +9 -7     exo/networking/manual/manual_discovery.py
  41. +0 -1     exo/networking/manual/network_topology_config.py
  42. +3 -2     exo/networking/peer_handle.py
  43. +11 -11   exo/networking/tailscale/tailscale_discovery.py
  44. +15 -26   exo/networking/tailscale/tailscale_helpers.py
  45. +2 -0     exo/networking/tailscale/test_tailscale_discovery.py
  46. +14 -15   exo/networking/udp/udp_discovery.py
  47. +2 -2     exo/orchestration/node.py
  48. +112 -122 exo/orchestration/standard_node.py
  49. +129 -62  exo/tinychat/index.css
  50. +26 -25   exo/tinychat/index.html
  51. +109 -14  exo/tinychat/index.js
  52. +44 -41   exo/tinychat/update_deps.py
  53. +4 -2     exo/topology/device_capabilities.py
  54. +1 -1     extra/start_openwebui.sh
  55. +1 -1     format.py
  56. +0 -5     lint.sh
  57. +0 -7     pyproject.toml
  58. +0 -43    ruff.toml
  59. +60 -0    scripts/build_exo.py
  60. +7 -0     scripts/compile_grpc.sh
  61. +10 -13   setup.py
  62. +1 -1     test/reconnect.sh
  63. +121 -0   test/test_model_helpers.py
  64. +8 -3     test/test_tokenizers.py

+ 35 - 35
.circleci/config.yml

@@ -20,12 +20,18 @@ commands:
          command: |
            source env/bin/activate

+            # Set CLANG=1 for tinygrad only
+            if [ "<<parameters.inference_engine>>" = "tinygrad" ]; then
+              pip install llvmlite
+              export TOKENIZERS_PARALLELISM=true SUPPORT_BF16=0 CLANG=1
+            fi
+
            # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output1.log &
            PID1=$!

            # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output2.log &
            PID2=$!

            # Wait for discovery
@@ -48,13 +54,6 @@ commands:
            # Check processes before proceeding
            check_processes

-            # Special handling for dummy engine
-            if [ "<<parameters.inference_engine>>" = "dummy" ]; then
-              expected_content="This is a dummy response"
-            else
-              expected_content="Michael Jackson"
-            fi
-
            echo "Sending request to first instance..."
            response_1=$(curl -s http://localhost:8000/v1/chat/completions \
              -H "Content-Type: application/json" \
@@ -127,6 +126,7 @@ jobs:
            METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
            echo "Running tokenizer tests..."
            python3 ./test/test_tokenizers.py
+            python3 ./test/test_model_helpers.py

  discovery_integration_test:
    macos:
@@ -149,9 +149,9 @@ jobs:
          name: Run discovery integration test
          command: |
            source env/bin/activate
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --disable-tui > output1.log 2>&1 &
            PID1=$!
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --disable-tui > output2.log 2>&1 &
            PID2=$!
            sleep 10
            kill $PID1 $PID2
@@ -223,29 +223,29 @@ jobs:
      - checkout
      - run: system_profiler SPHardwareDataType

-  # chatgpt_api_integration_test_tinygrad:
-  #   macos:
-  #     xcode: "16.0.0"
-  #   resource_class: m2pro.large
-  #   steps:
-  #     - checkout
-  #     - run:
-  #         name: Set up Python
-  #         command: |
-  #           brew install python@3.12
-  #           python3.12 -m venv env
-  #           source env/bin/activate
-  #     - run:
-  #         name: Install dependencies
-  #         command: |
-  #           source env/bin/activate
-  #           pip install --upgrade pip
-  #           pip install .
-  #     - run_chatgpt_api_test:
-  #         inference_engine: tinygrad
-  #         model_id: llama-3-8b
-  #         prompt: "Keep responses concise. Who was the king of pop?"
-  #         expected_output: "Michael Jackson"
+  chatgpt_api_integration_test_tinygrad:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: tinygrad
+          model_id: llama-3.2-1b
+          prompt: "Keep responses concise. Who was the king of pop?"
+          expected_output: "Michael Jackson"

 workflows:
   version: 2
@@ -254,6 +254,6 @@ workflows:
      - unit_test
      - discovery_integration_test
      - chatgpt_api_integration_test_mlx
+      - chatgpt_api_integration_test_tinygrad
      - chatgpt_api_integration_test_dummy
      - test_macos_m1
-      # - chatgpt_api_integration_test_tinygrad

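For local debugging of the two-node CI flow above, a rough sketch that reuses the same flags from the config (ports, cache paths and the model name are taken from the diff; this is illustrative, not part of the CI change):

```sh
# start two nodes that discover each other
HF_HOME="$PWD/.hf_cache_node1" exo --inference-engine tinygrad --node-id node1 \
  --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --disable-tui &
HF_HOME="$PWD/.hf_cache_node2" exo --inference-engine tinygrad --node-id node2 \
  --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --disable-tui &

# then query either node's ChatGPT-compatible API
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Who was the king of pop?"}]}'
```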
+ 1 - 1
.gitignore

@@ -4,6 +4,7 @@ test_weights.npz
 .exo_used_ports
 .exo_node_id
 .idea
+.DS_Store

 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -15,7 +16,6 @@ __pycache__/

 # Distribution / packaging
 /.Python
-/build/
 /develop-eggs/
 /dist/
 /downloads/

+ 0 - 472
.pylintrc

@@ -1,472 +0,0 @@
-[MASTER]
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code
-extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=.*node_service_pb2.*
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint.
-jobs=4
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Specify a configuration file.
-#rcfile=
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
-# E1101 for function binding
-# W0221 for Function class
-# W0105 for comment strings
-# E0401 for missing imports
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member,use-a-generator, no-else-return
-
-
-[REPORTS]
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio).You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Activate the evaluation score.
-score=yes
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=optparse.Values,sys.exit
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-
-[SIMILARITIES]
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging  or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='  '
-
-# Maximum number of characters on a single line.
-max-line-length=150
-
-# Maximum number of lines in a module
-max-module-lines=1000
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[BASIC]
-
-# Naming style matching correct argument names
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style
-#argument-rgx=
-
-# Naming style matching correct attribute names
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=foo,
-          bar,
-          baz,
-          toto,
-          tutu,
-          tata
-
-# Naming style matching correct class attribute names
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style
-#class-attribute-rgx=
-
-# Naming style matching correct class names
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-style
-#class-rgx=
-
-# Naming style matching correct constant names
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=i,
-           j,
-           k,
-           ex,
-           Run,
-           _
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style
-#inlinevar-rgx=
-
-# Naming style matching correct method names
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style
-#method-rgx=
-
-# Naming style matching correct module names
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-property-classes=abc.abstractproperty
-
-# Naming style matching correct variable names
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style
-#variable-rgx=
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in a if statement
-max-bool-expr=5
-
-# Maximum number of branch for function / method body
-max-branches=12
-
-# Maximum number of locals for function / method body
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body
-max-returns=6
-
-# Maximum number of statements in function / method body
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[IMPORTS]
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,
-                   TERMIOS,
-                   Bastion,
-                   rexec
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-[STRING]
-
-# This flag controls whether the implicit-str-concat should generate a warning
-# on implicit string concatenation in sequences defined over several lines.
-check-str-concat-over-line-jumps=yes
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=builtins.Exception

+ 25 - 5
README.md

@@ -121,14 +121,14 @@ exo

 That's it! No configuration required - exo will automatically discover the other device(s).

-exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
+exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:52415

-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:52415/v1/chat/completions. Examples with curl:

 #### Llama 3.2 3B:

 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "llama-3.2-3b",
@@ -140,7 +140,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llama 3.1 405B:

 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "llama-3.1-405b",
@@ -152,7 +152,7 @@ curl http://localhost:8000/v1/chat/completions \
 #### Llava 1.5 7B (Vision Language Model):

 ```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "llava-1.5-7b-hf",
@@ -208,6 +208,12 @@ With a custom prompt:
 exo run llama-3.2-3b --prompt "What is the meaning of exo?"
 ```

+### Model Storage
+
+Models by default are stored in `~/.cache/huggingface/hub`.
+
+You can set a different model storage location by setting the `HF_HOME` env var.
+
 ## Debugging

 Enable debug logs with the DEBUG environment variable (0-9).
@@ -222,6 +228,20 @@ For the **tinygrad** inference engine specifically, there is a separate DEBUG fl
 TINYGRAD_DEBUG=2 exo
 ```

+## Formatting
+
+We use [yapf](https://github.com/google/yapf) to format the code. To format the code, first install the formatting requirements:
+
+```sh
+pip3 install -e '.[formatting]'
+```
+
+Then run the formatting script:
+
+```sh
+python3 format.py ./exo
+```
+
 ## Known Issues

 - On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:

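The new Model Storage note can be exercised directly; a minimal sketch assuming a POSIX shell (the path is illustrative):

```sh
# keep model weights on an external drive instead of ~/.cache/huggingface/hub
HF_HOME=/Volumes/External/huggingface exo
```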
+ 18 - 2
configure_mlx.sh

@@ -1,2 +1,18 @@
-sudo sysctl iogpu.wired_lwm_mb=400000
-sudo sysctl iogpu.wired_limit_mb=180000
+#!/bin/bash
+
+# Get the total memory in MB
+TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
+
+# Set WIRED_LIMIT_MB to 80%
+WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
+# Set  WIRED_LWM_MB to 70%
+WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
+
+# Display the calculated values
+echo "Total memory: $TOTAL_MEM_MB MB"
+echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
+echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
+
+# Apply the values with sysctl
+sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB

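To see what the rewritten configure_mlx.sh computes, a worked example for a hypothetical 64 GB (65536 MB) machine; bash integer division rounds the results down:

```sh
TOTAL_MEM_MB=65536
echo $((TOTAL_MEM_MB * 80 / 100))  # iogpu.wired_limit_mb -> 52428
echo $((TOTAL_MEM_MB * 70 / 100))  # iogpu.wired_lwm_mb   -> 45875
```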
BIN
docs/exo-rounded.png


+ 1 - 1
examples/astra/astra/ContentView.swift

@@ -148,7 +148,7 @@ struct ContentView: View {
     @State private var voiceActivityThreshold: Float = 0.40
     @State private var silenceTimeThreshold = 1.0
     @State private var debugText = ""
-    @State private var apiEndpoint = "http://192.168.212.74:8000/v1/chat/completions"
+    @State private var apiEndpoint = "http://192.168.212.74:52415/v1/chat/completions"
     @State private var audioBuffer: [Float] = []
     @State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
     @State private var isInitialTranscription = true

+ 1 - 1
examples/chatgpt_api.sh

@@ -3,7 +3,7 @@
 # This works the same in a single-node set up and in a multi-node setup.
 # You need to start exo before running this by running `python3 main.py`.

-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 MODEL="llama-3.1-8b"
 PROMPT="What is the meaning of exo?"
 TEMPERATURE=0.7

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 53 - 41
exo/api/chatgpt_api.py

@@ -8,15 +8,14 @@ from typing import List, Literal, Union, Dict
 from aiohttp import web
 import aiohttp_cors
 import traceback
+import signal
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
-from exo.inference.shard import Shard
+from exo.helpers import PrefixDict, shutdown
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import model_base_shards
-from typing import Callable
-
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from typing import Callable, Optional

 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -27,6 +26,7 @@ class Message:
     return {"role": self.role, "content": self.content}
     return {"role": self.role, "content": self.content}
 
 
 
 
+
 class ChatCompletionRequest:
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float):
   def __init__(self, model: str, messages: List[Message], temperature: float):
     self.model = model
     self.model = model
@@ -117,19 +117,11 @@ def remap_messages(messages: List[Message]) -> List[Message]:
 def build_prompt(tokenizer, _messages: List[Message]):
   messages = remap_messages(_messages)
   prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  image_str = None
   for message in messages:
     if not isinstance(message.content, list):
       continue

-    for content in message.content:
-      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
-      # follows the convention in https://platform.openai.com/docs/guides/vision
-      if isinstance(content, dict) and content.get("type", None) == "image":
-        image_str = content.get("image", None)
-        break
-
-  return prompt, image_str
+  return prompt


 def parse_message(data: dict):
@@ -138,9 +130,9 @@ def parse_message(data: dict):
   return Message(data["role"], data["content"])
   return Message(data["role"], data["content"])
 
 
 
 
-def parse_chat_request(data: dict):
+def parse_chat_request(data: dict, default_model: str):
   return ChatCompletionRequest(
   return ChatCompletionRequest(
-    data.get("model", "llama-3.1-8b"),
+    data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
     data.get("temperature", 0.0),
   )
   )
@@ -152,9 +144,8 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt

-
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -163,6 +154,8 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
+    self.default_model = default_model or "llama-3.2-1b"
+
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
       allow_credentials=True,
@@ -177,13 +170,24 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
+    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
+    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
 
 
-    self.static_dir = Path(__file__).parent.parent / "tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
 
 
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
     self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)
 
 
   async def timeout_middleware(self, app, handler):
   async def timeout_middleware(self, app, handler):
     async def middleware(request):
     async def middleware(request):
@@ -191,6 +195,7 @@ class ChatGPTAPI:
         return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
       except asyncio.TimeoutError:
         return web.json_response({"detail": "Request timed out"}, status=408)
+
     return middleware

   async def log_request(self, app, handler):
@@ -203,14 +208,25 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")

+  async def handle_healthcheck(self, request):
+    return web.json_response({"status": "ok"})
+
+  async def handle_model_support(self, request):
+    return web.json_response({
+      "model pool": {
+        model_name: pretty_name.get(model_name, model_name)
+        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
+      }
+    })
+
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
-    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = build_base_shard(self.default_model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

   async def handle_get_download_progress(self, request):
@@ -222,29 +238,28 @@ class ChatGPTAPI:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
     return web.json_response(progress_data)
     return web.json_response(progress_data)
 
 
-
   async def handle_post_chat_completions(self, request):
   async def handle_post_chat_completions(self, request):
     data = await request.json()
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     stream = data.get("stream", False)
-    chat_request = parse_chat_request(data)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
-      chat_request.model = "llama-3.1-8b"
-    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
+    chat_request = parse_chat_request(data, self.default_model)
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
+    if not chat_request.model or chat_request.model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      chat_request.model = self.default_model
+    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
     if not shard:
-      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
       return web.json_response(
       return web.json_response(
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
         status=400,
         status=400,
       )
       )
 
 
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
 
-    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
     if self.on_chat_completion_request:
       try:
       try:
@@ -267,13 +282,10 @@ class ChatGPTAPI:
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
     callback = self.node.on_token.register(callback_id)
 
 
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
 
 
     try:
     try:
-      await asyncio.wait_for(
-        asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))),
-        timeout=self.response_timeout
-      )
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
 
 
@@ -356,7 +368,7 @@ class ChatGPTAPI:
       deregistered_callback = self.node.on_token.deregister(callback_id)
       if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

-  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+  async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
     site = web.TCPSite(runner, host, port)

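The new endpoints added to chatgpt_api.py above can be poked with curl once a node is running on the new default port 52415 (responses shown are abbreviated/illustrative):

```sh
curl http://localhost:52415/healthcheck       # {"status": "ok"}
curl http://localhost:52415/modelpool         # {"model pool": {...}}
curl -X POST http://localhost:52415/quit      # asks the node to shut down gracefully
```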
+ 34 - 6
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
 import asyncio
 import aiohttp
 import json
 import os
+import sys
+import shutil
 from urllib.parse import urljoin
 from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
 from datetime import datetime, timedelta
@@ -9,7 +13,7 @@ from fnmatch import fnmatch
 from pathlib import Path
 from typing import Generator, Iterable, TypeVar, TypedDict
 from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
 from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
 from exo.inference.shard import Shard
 import aiofiles
@@ -17,7 +21,6 @@ from aiofiles import os as aios

 T = TypeVar("T")

-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
@@ -70,8 +73,10 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
     return pattern + "*"
   return pattern
   return pattern
 
 
+
 def get_hf_endpoint() -> str:
 def get_hf_endpoint() -> str:
-    return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+
 
 
 def get_hf_home() -> Path:
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   """Get the Hugging Face home directory."""
@@ -97,9 +102,22 @@ async def get_auth_headers():

 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"

+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)
+  async for path in source_dir.iterdir():
+    if path.is_dir() and path.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if dest_path.exists():
+        if DEBUG>=1: print(f"skipping moving {dest_path}. File already exists")
+      else:
+        await aios.rename(str(path), str(dest_path))
+

 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
@@ -394,7 +412,7 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:


 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
+  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
@@ -407,6 +425,16 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     elif shard.is_last_layer():
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
+
+async def has_hf_home_read_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.R_OK)
+  except OSError: return False
+
+async def has_hf_home_write_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.W_OK)
+  except OSError: return False

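The get_allow_patterns change above is a real fix, not just formatting: set("*.safetensors") builds a set of the string's individual characters, while set(["*.safetensors"]) keeps the glob pattern intact. A quick check (output order may vary):

```sh
python3 -c 'print(set("*.safetensors")); print(set(["*.safetensors"]))'
# {'s', 'f', 'o', '.', 'r', 'n', 'a', 't', 'e', '*'}
# {'*.safetensors'}
```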
+ 8 - 6
exo/download/hf/hf_shard_download.py

@@ -7,6 +7,7 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import model_cards, get_repo


 class HFShardDownloader(ShardDownloader):
@@ -17,11 +18,12 @@ class HFShardDownloader(ShardDownloader):
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    repo_name = get_repo(shard.model_id, inference_engine_name)
     if shard in self.completed_downloads:
       return self.completed_downloads[shard]
     if self.quick_check:
-      repo_root = get_repo_root(shard.model_id)
+      repo_root = get_repo_root(repo_name)
       snapshots_dir = repo_root/"snapshots"
       if snapshots_dir.exists():
         visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
@@ -51,7 +53,7 @@ class HFShardDownloader(ShardDownloader):
     self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}

     # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard))
+    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
     self.active_downloads[shard] = download_task
     try:
       path = await download_task
@@ -63,14 +65,14 @@ class HFShardDownloader(ShardDownloader):
       if shard in self.active_downloads:
         self.active_downloads.pop(shard)

-  async def _download_shard(self, shard: Shard) -> Path:
+  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
     async def wrapped_progress_callback(event: RepoProgressEvent):
       self._on_progress.trigger_all(shard, event)

-    weight_map = await get_weight_map(shard.model_id)
+    weight_map = await get_weight_map(repo_name)
     allow_patterns = get_allow_patterns(weight_map, shard)

-    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
+    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)

   @property
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:

+ 4 - 2
exo/download/shard_download.py

@@ -8,7 +8,7 @@ from exo.helpers import AsyncCallbackSystem

 class ShardDownloader(ABC):
   @abstractmethod
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
@@ -17,6 +17,7 @@ class ShardDownloader(ABC):

         Args:
             shard (Shard): The shard to download.
+            inference_engine_name (str): The inference engine used on the node hosting the shard
         """
     pass

@@ -25,8 +26,9 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
     pass
 
 
+
 class NoopShardDownloader(ShardDownloader):
 class NoopShardDownloader(ShardDownloader):
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     return Path("/tmp/noop_shard")
     return Path("/tmp/noop_shard")
 
 
   @property
   @property
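The interface change ripples into every downloader implementation: ensure_shard now also receives the inference engine's name, which lets a downloader choose an engine-specific repository for the same logical model. A toy mirror of the updated contract (not exo's actual base class; the directory layout is invented):

import asyncio
from abc import ABC, abstractmethod
from pathlib import Path
from types import SimpleNamespace

class MiniShardDownloader(ABC):
  @abstractmethod
  async def ensure_shard(self, shard, inference_engine_name: str) -> Path:
    ...

class CachedDirDownloader(MiniShardDownloader):
  async def ensure_shard(self, shard, inference_engine_name: str) -> Path:
    # Pick an engine-specific local directory, e.g. an MLX quantisation vs.
    # plain safetensors for tinygrad.
    suffix = "mlx" if "MLX" in inference_engine_name else "safetensors"
    return Path(f"/tmp/models/{shard.model_id}-{suffix}")

shard = SimpleNamespace(model_id="llama-3.2-1b")
print(asyncio.run(CachedDirDownloader().ensure_shard(shard, "MLXDynamicShardInferenceEngine")))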

+ 21 - 1
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
 import os
+import sys
 import asyncio
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
 import socket
@@ -170,7 +171,7 @@ def is_valid_uuid(val):
 
 
 
 
 def get_or_create_node_id():
 def get_or_create_node_id():
-  NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
+  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
   try:
   try:
     if NODE_ID_FILE.is_file():
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:
       with open(NODE_ID_FILE, "r") as f:
@@ -234,3 +235,22 @@ def get_all_ip_addresses():
   except:
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
     return ["localhost"]
+
+
+async def shutdown(signal, loop, server):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+  loop.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
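shutdown() is written to be installed as a signal handler. A hedged sketch of how a caller might wire it up on POSIX, with a stand-in FakeServer instead of exo's real server object:

import asyncio
import signal

class FakeServer:
  async def stop(self):
    print("server stopped")

async def graceful_shutdown(sig, loop, server):
  # Same shape as the shutdown() helper above: cancel outstanding tasks,
  # stop the server, then stop the loop.
  print(f"Received exit signal {sig.name}...")
  tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
  for t in tasks:
    t.cancel()
  await asyncio.gather(*tasks, return_exceptions=True)
  await server.stop()
  loop.stop()

def install_signal_handlers(loop, server):
  # POSIX-only: Ctrl-C (SIGINT) and SIGTERM both route through the graceful path.
  for sig in (signal.SIGINT, signal.SIGTERM):
    loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(graceful_shutdown(s, loop, server)))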

+ 11 - 12
exo/inference/debug_inference_engine.py

@@ -13,32 +13,31 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
   _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
 
 
   prompt = "In a single word only, what is the last name of the president of the United States? "
   prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )
   )
 
 
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
     input_data=resp1,
-    inference_state=inference_state_1,
   )
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  token2 = await inference_engine_2.sample(resp2)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=token2,
   )
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
     input_data=resp3,
-    inference_state=inference_state_3,
   )
   )
 
 
   print(f"{resp2=}")
   print(f"{resp2=}")

+ 17 - 38
exo/inference/dummy_inference_engine.py

@@ -1,59 +1,38 @@
 from typing import Optional, Tuple, TYPE_CHECKING
 from typing import Optional, Tuple, TYPE_CHECKING
 import numpy as np
 import numpy as np
+import random
+import string
 import asyncio
 import asyncio
 import json
 import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
+def random_string(length: int):
+  return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
+  
 
 
 class DummyInferenceEngine(InferenceEngine):
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
   def __init__(self):
     self.shard = None
     self.shard = None
     self.vocab_size = 1000
     self.vocab_size = 1000
+    self.hidden_size = 256
     self.eos_token_id = 0
     self.eos_token_id = 0
     self.latency_mean = 0.1
     self.latency_mean = 0.1
     self.latency_stddev = 0.02
     self.latency_stddev = 0.02
 
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-    try:
-      await self.ensure_shard(shard)
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size)
 
 
-      # Generate random tokens
-      output_length = np.random.randint(1, 10)
-      output = np.random.randint(1, self.vocab_size, size=(1, output_length))
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
 
 
-      # Simulate latency
-      await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-      # Randomly decide if finished
-      is_finished = np.random.random() < 0.2
-      if is_finished:
-        output = np.array([[self.eos_token_id]])
-
-      new_state = json.dumps({"dummy_state": "some_value"})
-
-      return output, new_state, is_finished
-    except Exception as e:
-      print(f"Error in DummyInferenceEngine.infer_prompt: {str(e)}")
-      return np.array([[self.eos_token_id]]), json.dumps({"error": str(e)}), True
-
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    state = json.loads(inference_state or "{}")
-    start_pos = state.get("start_pos", 0)
-
-    output_length = np.random.randint(1, 10)
-    output = np.random.randint(1, self.vocab_size, size=(1, output_length))
-
-    await asyncio.sleep(max(0, np.random.normal(self.latency_mean, self.latency_stddev)))
-
-    is_finished = np.random.random() < 0.2
-    if is_finished:
-      output = np.array([[self.eos_token_id]])
-
-    start_pos += input_data.shape[1] + output_length
-    new_state = json.dumps({"start_pos": start_pos})
-
-    return output, new_state, is_finished
+    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
+    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
+    return output
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:
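The rewritten dummy engine encodes the shape contract the real engines follow: the first shard takes token ids, intermediate shards exchange hidden-size activations, and the last shard emits per-token logits that sample() reduces to a single token id. A standalone numpy illustration of that contract (not exo's exact code):

import numpy as np

VOCAB_SIZE, HIDDEN_SIZE = 1000, 256

def dummy_shard_forward(x: np.ndarray, is_first: bool, is_last: bool) -> np.ndarray:
  # Token ids go in on the first shard, hidden states flow between shards,
  # and the last shard produces logits over the vocabulary.
  seq_len = x.shape[-1] if is_first else x.shape[1]
  return np.random.random((1, seq_len, VOCAB_SIZE if is_last else HIDDEN_SIZE))

tokens = np.random.randint(1, VOCAB_SIZE, size=(1, 4))          # "encoded" prompt
h = dummy_shard_forward(tokens, is_first=True, is_last=False)   # (1, 4, 256)
logits = dummy_shard_forward(h, is_first=False, is_last=True)   # (1, 4, 1000)
next_token = int(np.argmax(logits[:, -1, :]))                   # greedy "sample"
print(h.shape, logits.shape, next_token)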

+ 22 - 3
exo/inference/inference_engine.py

@@ -9,13 +9,32 @@ from .shard import Shard
 
 
 class InferenceEngine(ABC):
 class InferenceEngine(ABC):
   @abstractmethod
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  @abstractmethod
+  async def sample(self, x: np.ndarray) -> np.ndarray:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass
     pass
 
 
+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    pass
+  
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
+    tokens = await self.encode(shard, prompt)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
+    return output_data 
+
+inference_engine_classes = {
+  "mlx": "MLXDynamicShardInferenceEngine",
+  "tinygrad": "TinygradDynamicShardInferenceEngine",
+  "dummy": "DummyInferenceEngine",
+}
 
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:
   if DEBUG >= 2:
@@ -33,4 +52,4 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
   elif inference_engine_name == "dummy":
   elif inference_engine_name == "dummy":
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     from exo.inference.dummy_inference_engine import DummyInferenceEngine
     return DummyInferenceEngine()
     return DummyInferenceEngine()
-  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
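With the interface reduced to four primitives, a generation loop is just their composition: encode a prompt (via the default infer_prompt), run infer_tensor, sample a token from the output, feed it back, and decode at the end. A hypothetical driver loop against this interface; exo's real orchestration lives in standard_node.py and additionally forwards activations between peers:

import numpy as np

async def generate(engine, shard, prompt: str, max_tokens: int = 64, eos_token_id: int = 0) -> str:
  # eos_token_id is model-specific; 0 here is only a placeholder.
  out = await engine.infer_prompt("req-1", shard, prompt)        # logits (or hidden states mid-pipeline)
  generated: list[int] = []
  for _ in range(max_tokens):
    token = int(np.reshape(await engine.sample(out), -1)[0])     # reduce logits to a token id
    if token == eos_token_id:
      break
    generated.append(token)
    out = await engine.infer_tensor("req-1", shard, np.array([[token]]))
  return await engine.decode(shard, np.array(generated))

Per-request KV state no longer travels as a JSON inference_state string; it is held inside the engine's stateful model, keyed by request_id.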

+ 1 - 1
exo/inference/mlx/models/base.py

@@ -1,7 +1,7 @@
 from typing import Optional
 from typing import Optional
 import mlx.core as mx
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 
 
 
 
 class IdentityBlock(nn.Module):
 class IdentityBlock(nn.Module):

+ 1 - 1
exo/inference/mlx/models/deepseek_v2.py

@@ -4,7 +4,7 @@ from typing import Optional
 import mlx.core as mx
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
 
 
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
 from .base import IdentityBlock
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard

+ 118 - 0
exo/inference/mlx/models/gemma2.py

@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class GemmaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if args.shard.is_first_layer() or args.shard.is_last_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if args.shard.start_layer <= i <= args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if args.shard.is_last_layer():
+      self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+      h = h * (self.args.hidden_size**0.5)
+    else:
+      h = inputs
+
+    mask = None
+    if h.ndim > 1 and h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = GemmaModel(args)
+    if args.shard.is_last_layer():
+      self.final_logit_softcapping = args.final_logit_softcapping
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.model.embed_tokens.as_linear(out)
+      out = mx.tanh(out / self.final_logit_softcapping)
+      out = out * self.final_logit_softcapping
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.head_dim
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads
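sanitize() is where sharding meets the checkpoint: a shard only instantiates its own decoder layers (plus the embedding on the first or last shard and the final norm on the last), so everything else is dropped from the state dict before loading. A toy, framework-free version of that filter:

def filter_weights_for_shard(weights: dict, start_layer: int, end_layer: int,
                             is_first: bool, is_last: bool) -> dict:
  # Toy version of sanitize(): keep only the tensors this shard instantiates.
  kept = {}
  for key, value in weights.items():
    if key.startswith("model.layers."):
      layer = int(key.split(".")[2])
      if start_layer <= layer <= end_layer:
        kept[key] = value
    elif key.startswith("model.embed_tokens") and (is_first or is_last):
      # Gemma ties the output head to the embedding, so the last shard needs it too.
      kept[key] = value
    elif key.startswith("model.norm") and is_last:
      kept[key] = value
  return kept

demo = {
  "model.layers.0.mlp.weight": 1,
  "model.layers.5.mlp.weight": 2,
  "model.embed_tokens.weight": 3,
  "model.norm.weight": 4,
}
print(filter_weights_for_shard(demo, start_layer=0, end_layer=3, is_first=True, is_last=False))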

+ 2 - 1
exo/inference/mlx/models/qwen2.py

@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
 
 
     self.shard = Shard(**self.shard)
     self.shard = Shard(**self.shard)
 
 
+
 class Qwen2Model(nn.Module):
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
   def __init__(self, args: ModelArgs):
     super().__init__()
     super().__init__()
@@ -57,7 +58,7 @@ class Qwen2Model(nn.Module):
       mask = create_attention_mask(h, cache)
       mask = create_attention_mask(h, cache)
 
 
     if cache is None:
     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)
 
 
     for layer, c in zip(self.layers, cache):
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)
       h = layer(h, mask, c)

+ 46 - 21
exo/inference/mlx/sharded_inference_engine.py

@@ -1,14 +1,35 @@
 import numpy as np
 import numpy as np
 import mlx.core as mx
 import mlx.core as mx
+import mlx.nn as nn
 from ..inference_engine import InferenceEngine
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulShardedModel
+from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
 from ..shard import Shard
-from typing import Optional
+from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 import asyncio
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from functools import partial
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> Tuple[mx.array, float]:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
+
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token
 
 
 class MLXDynamicShardInferenceEngine(InferenceEngine):
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -16,35 +37,39 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.executor = ThreadPoolExecutor(max_workers=1)
 
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
+    return out
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    if image_str:
-      image = await get_image_from_str(image_str)
-      tokenize = partial(self.tokenizer, prompt, image, return_tensors="np")
-      inputs = await loop.run_in_executor(self.executor, tokenize)
-      pixel_values = mx.array(inputs["pixel_values"])
-      input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
-    else:
-      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return np.array(tokens)
 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    return output_data
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:
       return
       return
 
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
 
     if self.shard != shard:
     if self.shard != shard:
       loop = asyncio.get_running_loop()
       loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
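Sampling now happens in the engine rather than inside the model: the last shard returns full logits and sample() applies temperature and optional nucleus (top-p) truncation, with the truncated case handled by top_p_sampling (previously imported from mlx_lm.sample_utils in the removed sharded_model.py). A numpy sketch of those rules, independent of MLX:

import numpy as np

def sample_np(logits: np.ndarray, temp: float = 0.0, top_p: float = 1.0, seed: int = 0) -> int:
  # Greedy at temp == 0, otherwise temperature-scaled categorical sampling,
  # optionally truncated to the smallest set of tokens whose cumulative
  # probability reaches top_p (nucleus sampling).
  if temp == 0:
    return int(np.argmax(logits))
  scaled = logits / temp
  probs = np.exp(scaled - np.max(scaled))
  probs /= probs.sum()
  if 0 < top_p < 1.0:
    order = np.argsort(-probs)
    keep = np.cumsum(probs[order]) <= top_p
    keep[0] = True                      # never drop the most likely token
    mask = np.zeros_like(probs)
    mask[order[keep]] = 1.0
    probs = probs * mask / (probs * mask).sum()
  return int(np.random.default_rng(seed).choice(len(probs), p=probs))

print(sample_np(np.array([0.1, 2.0, 0.3])))                      # greedy -> 1
print(sample_np(np.array([0.1, 2.0, 0.3]), temp=0.7, top_p=0.9))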

+ 0 - 86
exo/inference/mlx/sharded_model.py

@@ -1,86 +0,0 @@
-from typing import Dict, Generator, Optional, Tuple
-from collections import OrderedDict
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import KVCache, RotatingKVCache
-from mlx_lm.sample_utils import top_p_sampling
-
-from ..shard import Shard
-
-# TODO: support a speculative model so we can parallelise compute across devices
-class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
-    self.shard = shard
-    self.model = model
-    self.max_kv_size = max_kv_size
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
-
-  def step(
-    self,
-    request_id: str,
-    x,
-    pixel_values=None,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-      if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-        logits[:, indices] += values
-
-      if temp == 0:
-        token = mx.argmax(logits, axis=-1)
-      else:
-        if top_p > 0 and top_p < 1.0:
-          token = top_p_sampling(logits, top_p, temp)
-        else:
-          token = mx.random.categorical(logits*(1/temp))
-
-      return token
-
-    y = x
-
-    if request_id not in self.caches:
-      self.init_cache(request_id)
-    else:
-      self.caches.move_to_end(request_id)
-
-    cache = self.caches[request_id]
-
-    if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
-    else:
-      output = self.model(y, pixel_values=pixel_values, cache=cache)
-
-    if self.shard.is_last_layer():
-      logits = output[:, -1, :]
-      y = sample(logits)
-      return y
-    else:
-      return output
-
-  def __call__(
-    self,
-    request_id: str,
-    x,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
-
-  def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    if self.max_kv_size is not None:
-      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-    else:
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache

+ 14 - 8
exo/inference/mlx/sharded_utils.py

@@ -12,15 +12,16 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from PIL import Image
 from io import BytesIO
 from io import BytesIO
 import base64
 import base64
+import traceback
 
 
 import mlx.core as mx
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.nn as nn
 from transformers import AutoProcessor
 from transformers import AutoProcessor
 
 
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
-from mlx_lm.tuner.utils import apply_lora_layers
 
 
 from exo import DEBUG
 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard
 from ..shard import Shard
 
 
 
 
@@ -53,6 +54,7 @@ def _get_classes(config: dict):
   except ImportError:
   except ImportError:
     msg = f"Model type {model_type} not supported."
     msg = f"Model type {model_type} not supported."
     logging.error(msg)
     logging.error(msg)
+    traceback.print_exc()
     raise ValueError(msg)
     raise ValueError(msg)
 
 
   return arch.Model, arch.ModelArgs
   return arch.Model, arch.ModelArgs
@@ -67,7 +69,6 @@ def load_config(model_path: Path) -> dict:
     raise
     raise
   return config
   return config
 
 
-
 def load_model_shard(
 def load_model_shard(
   model_path: Path,
   model_path: Path,
   shard: Shard,
   shard: Shard,
@@ -130,8 +131,17 @@ def load_model_shard(
 
 
   model_class, model_args_class = _get_classes(config=config)
   model_class, model_args_class = _get_classes(config=config)
 
 
+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x, *args, **kwargs)
+      return y
+
   model_args = model_args_class.from_dict(config)
   model_args = model_args_class.from_dict(config)
-  model = model_class(model_args)
+  model = ShardedModel(model_args)
 
 
   if hasattr(model, "sanitize"):
   if hasattr(model, "sanitize"):
     weights = model.sanitize(weights)
     weights = model.sanitize(weights)
@@ -157,7 +167,6 @@ def load_model_shard(
   model.eval()
   model.eval()
   return model
   return model
 
 
-
 async def load_shard(
 async def load_shard(
   model_path: str,
   model_path: str,
   shard: Shard,
   shard: Shard,
@@ -167,9 +176,6 @@ async def load_shard(
   lazy: bool = False,
   lazy: bool = False,
 ) -> Tuple[nn.Module, TokenizerWrapper]:
 ) -> Tuple[nn.Module, TokenizerWrapper]:
   model = load_model_shard(model_path, shard, lazy, model_config)
   model = load_model_shard(model_path, shard, lazy, model_config)
-  if adapter_path is not None:
-    model = apply_lora_layers(model, adapter_path)
-    model.eval()
 
 
   # TODO: figure out a generic solution
   # TODO: figure out a generic solution
   if model.model_type == "llava":
   if model.model_type == "llava":
@@ -178,7 +184,7 @@ async def load_shard(
     processor.encode = processor.tokenizer.encode
     processor.encode = processor.tokenizer.encode
     return model, processor
     return model, processor
   else:
   else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
     return model, tokenizer
     return model, tokenizer
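_get_classes resolves the model architecture dynamically from config["model_type"], which is why the new gemma2.py file is picked up without any registry changes. A sketch of that lookup; the package path below is a placeholder, not necessarily the one sharded_utils uses:

import importlib
import logging
import traceback

def get_model_classes(config: dict, package: str = "exo.inference.mlx.models"):
  # Map config["model_type"] to a module under the models package and pull the
  # Model / ModelArgs classes out of it.
  model_type = config["model_type"]
  try:
    arch = importlib.import_module(f"{package}.{model_type}")
  except ImportError:
    logging.error(f"Model type {model_type} not supported.")
    traceback.print_exc()
    raise ValueError(f"Model type {model_type} not supported.")
  return arch.Model, arch.ModelArgs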
 
 
 
 

+ 42 - 0
exo/inference/mlx/stateful_model.py

@@ -0,0 +1,42 @@
+from typing import Dict, Tuple
+from collections import OrderedDict
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import make_prompt_cache
+
+from ..shard import Shard
+
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+  
+  def init_cache(self, request_id: str):
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    # if self.max_kv_size is not None:
+      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str):
+    if request_id not in self.caches:
+      self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
+
+    y = self.model(x, cache=cache)
+    return y
+    
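StatefulModel keeps one prompt cache per request_id and evicts the least recently used one once max_caches is reached, which bounds memory across concurrent chats. A standalone sketch of that eviction policy (the cache payload here is a plain list standing in for the real KV cache):

from collections import OrderedDict

class RequestCacheLRU:
  def __init__(self, make_cache, max_caches: int = 2):
    self.make_cache = make_cache
    self.max_caches = max_caches
    self.caches: OrderedDict[str, object] = OrderedDict()

  def get(self, request_id: str):
    if request_id in self.caches:
      self.caches.move_to_end(request_id)        # mark as most recently used
    else:
      if len(self.caches) >= self.max_caches:
        self.caches.popitem(last=False)          # evict the oldest request
      self.caches[request_id] = self.make_cache()
    return self.caches[request_id]

lru = RequestCacheLRU(make_cache=list, max_caches=2)
lru.get("a"); lru.get("b"); lru.get("a"); lru.get("c")   # "b" is evicted
print(list(lru.caches))                                   # ['a', 'c']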

+ 4 - 4
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 
 
@@ -12,9 +12,9 @@ full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Ins
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 
 
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
+full = StatefulModel(shard_full, full_model_shard)
+m1 = StatefulModel(shard1, model_shard1)
+m2 = StatefulModel(shard2, model_shard2)
 
 
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))

+ 2 - 2
exo/inference/mlx/test_sharded_llava.py

@@ -5,9 +5,9 @@ from PIL import Image
 from io import BytesIO
 from io import BytesIO
 
 
 import mlx.core as mx
 import mlx.core as mx
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 
 
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 
 

+ 39 - 42
exo/inference/test_dummy_inference_engine.py

@@ -4,53 +4,50 @@ import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 
 
+
 class MockShardDownloader:
 class MockShardDownloader:
-    async def ensure_shard(self, shard):
-        pass
+  async def ensure_shard(self, shard):
+    pass
+
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
 async def test_dummy_inference_specific():
-    engine = DummyInferenceEngine(MockShardDownloader())
-    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    test_prompt = "This is a test prompt"
-    
-    result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
-    
-    print(f"Inference result shape: {result.shape}")
-    print(f"Inference state: {state}")
-    print(f"Is finished: {is_finished}")
-    
-    assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
-    assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+  engine = DummyInferenceEngine(MockShardDownloader())
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+
+  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+
+  print(f"Inference result shape: {result.shape}")
+
+  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
 async def test_dummy_inference_engine():
-    # Initialize the DummyInferenceEngine
-    engine = DummyInferenceEngine(MockShardDownloader())
-    
-    # Create a test shard
-    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    
-    # Test infer_prompt
-    output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    # Test infer_tensor
-    input_tensor = np.array([[1, 2, 3]])
-    output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    print("All tests passed!")
+  # Initialize the DummyInferenceEngine
+  engine = DummyInferenceEngine(MockShardDownloader())
+
+  # Create a test shard
+  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+
+  # Test infer_prompt
+  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  # Test infer_tensor
+  input_tensor = np.array([[1, 2, 3]])
+  output = await engine.infer_tensor("test_id", shard, input_tensor)
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  print("All tests passed!")
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    import asyncio
-    asyncio.run(test_dummy_inference_engine())
-    asyncio.run(test_dummy_inference_specific())
+  import asyncio
+  asyncio.run(test_dummy_inference_engine())
+  asyncio.run(test_dummy_inference_specific())

+ 14 - 24
exo/inference/test_inference_engine.py

@@ -11,45 +11,40 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+  token_full = token_full.reshape(1, -1)
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
     shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )
   )
 
 
   pp = n_layers // 2
   pp = n_layers // 2
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp1,
     input_data=resp1,
-    inference_state=inference_state_1,
   )
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  tokens2 = await inference_engine_1.sample(resp2)
+  tokens2 = tokens2.reshape(1, -1)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
     shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=tokens2,
   )
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     "B",
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
     input_data=resp3,
     input_data=resp3,
-    inference_state=inference_state_3,
   )
   )
 
 
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(resp_full, resp2)
   assert np.array_equal(next_resp_full, resp4)
   assert np.array_equal(next_resp_full, resp4)
 
 
 
 
-asyncio.run(test_inference_engine(
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Llama-3.2-1B-Instruct-4bit",
-  16
-))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
 
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
   import tinygrad
@@ -57,10 +52,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
   asyncio.run(
-    test_inference_engine(
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-      32
-    )
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
   )
   )
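The assertions resp_full == resp2 and next_resp_full == resp4 are the crux of this test: for a purely sequential stack, running shard A and then shard B must reproduce the full model exactly. A tiny numpy illustration of why that composition property holds (toy matrices standing in for transformer layers):

import numpy as np

rng = np.random.default_rng(0)
layers = [rng.standard_normal((8, 8)) for _ in range(4)]  # a toy 4-"layer" stack

def run(x, start, end):
  # Apply layers start..end inclusive, like a Shard(start_layer, end_layer).
  for w in layers[start:end + 1]:
    x = np.tanh(x @ w)
  return x

x = rng.standard_normal((1, 8))
full = run(x, 0, 3)
split = run(run(x, 0, 1), 2, 3)       # first shard, then second shard
print("sharded == full:", np.allclose(full, split))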

+ 34 - 36
exo/inference/tinygrad/inference.py

@@ -1,17 +1,17 @@
 from pathlib import Path
 from pathlib import Path
 import json
 import json
 import os
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
 import asyncio
 import asyncio
 
 
 Tensor.no_grad = True
 Tensor.no_grad = True
@@ -22,7 +22,17 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 ALPHA_P = 0.0
 MODEL_PARAMS = {
 MODEL_PARAMS = {
-  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "1B": {
+    "args": {
+      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "3B": {
+    "args": {
+      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
   "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
   "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 }
 
 
@@ -30,8 +40,7 @@ MODEL_PARAMS = {
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   # build model
   linear = nn.Linear
   linear = nn.Linear
-  with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
 
   # load weights
   # load weights
   if model_path.is_dir():
   if model_path.is_dir():
@@ -48,54 +57,43 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model
   return model
 
 
-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard = None
     self.shard_downloader = shard_downloader
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
     self.executor = ThreadPoolExecutor(max_workers=1)
 
 
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
 
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
+  
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
 
 
-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
 
 
   async def ensure_shard(self, shard: Shard):
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
     if self.shard == shard:
       return
       return
 
 
-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
 
     if self.shard != shard:
     if self.shard != shard:
-      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+      loop = asyncio.get_running_loop()
+      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
 
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.tokenizer = await resolve_tokenizer(tokenizer_path)
       self.shard = shard
       self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
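The model size is inferred from the model id, so the new 1B and 3B entries in MODEL_PARAMS are picked up purely by substring match, with 70B as the fallback. A small mirror of that selection:

def detect_model_size(model_id: str) -> str:
  # Same rule as above: match a size token in the id, default to 70B.
  lowered = model_id.lower()
  for size in ("1b", "3b", "8b"):
    if size in lowered:
      return size.upper()
  return "70B"

for mid in ("llama-3.2-1b", "llama-3.2-3b", "llama-3-8b", "llama-3.1-70b"):
  print(mid, "->", detect_model_size(mid))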

+ 61 - 36
exo/inference/tinygrad/models/llama.py

@@ -1,11 +1,23 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from tinygrad.helpers import getenv
+from collections import OrderedDict
 
 
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
+    freqs *= (original_max_pos_emb/end)**(1.0/factor)
+
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
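The new rope_scaling branch is what lets the 1B/3B Llama 3.2 configs in inference.py stretch their rotary frequencies for a longer context. A numpy transcription of exactly the adjustment applied in this hunk (a simplified llama-3-style scaling, not the full per-wavelength interpolation):

import numpy as np

def scaled_rope_freqs(dim: int, end: int, theta: float, rope_scaling: dict) -> np.ndarray:
  freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))
  factor = rope_scaling.get("factor", 1.0)
  low = rope_scaling.get("low_freq_factor", 1.0)
  high = rope_scaling.get("high_freq_factor", 1.0)
  original = rope_scaling.get("original_max_position_embeddings", end)
  freqs[: dim // 4] *= low                     # scale the first quarter of the bands
  freqs[dim // 4:] *= high                     # scale the remaining bands
  freqs *= (original / end) ** (1.0 / factor)  # global stretch toward the longer context
  return freqs

cfg = {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192}
print(scaled_rope_freqs(dim=64, end=16384, theta=500000, rope_scaling=cfg)[:4])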
@@ -36,7 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
 
-
 class Attention:
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_heads = n_heads
@@ -50,7 +61,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
       xqkv = x @ self.wqkv.T
@@ -65,19 +76,16 @@ class Attention:
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     bsz, seqlen, _, _ = xq.shape
     bsz, seqlen, _, _ = xq.shape
 
 
-    # create kv cache
-    if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
-      if isinstance(x.device, tuple):
-        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+    if cache is not None:
+      # update the cache
+      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
+      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    # update the cache
-    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
-
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    else:
+      keys = xk
+      values = xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -103,13 +111,13 @@ class TransformerBlock:
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
 
 
 # standard openai sampling
 # standard openai sampling
-def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert logits.ndim == 1, "only works on 1d tensors"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
@@ -176,42 +184,56 @@ class Transformer:
     rope_theta=10000,
     rope_theta=10000,
     max_context=1024,
     max_context=1024,
     jit=True,
     jit=True,
-    feed_forward=FeedForward
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False,
   ):
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
-    self.forward_jit = TinyJit(self.forward) if jit else None
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
     self.shard = shard
     self.shard = shard
 
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
 
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(x)
-    else:
-      h = x
+    h = x
 
 
-    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask)
+      h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
 
     if self.shard.is_last_layer():
     if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float()[:, -1, :]
-      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float().realize()
+      return logits
     else:
     else:
       return h
       return h
 
 
-  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
+  def embed(self, inputs: Tensor):
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(inputs)
+    else:
+      h = inputs
+    return h
+
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
-    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+    h = self.embed(tokens)
+    return self.forward(h, start_pos, cache=cache)
 
 
 
 
 # *** helpers ***
 # *** helpers ***
@@ -245,7 +267,10 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
         v = permute(v, n_heads)
         v = permute(v, n_heads)
       elif "k_proj" in k:
       elif "k_proj" in k:
         v = permute(v, n_kv_heads)
         v = permute(v, n_kv_heads)
-    sd[keymap[k]] = v
+    if k in keymap:
+      sd[keymap[k]] = v
+    else:
+      sd[k] = v
   return sd
   return sd
 
 
 
 

+ 42 - 0
exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,42 @@
+from tinygrad import Tensor, Variable
+from tinygrad.helpers import getenv  # needed for the SHARD_KVCACHE check below
+from collections import OrderedDict
+from typing import List
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class ModelState:
+  cache: List[Tensor]
+  start: int 
+  def __init__(self, cache: List[Tensor], start: int = 0):
+    self.cache = cache
+    self.start = start
+
+class StatefulModel:
+  def __init__(self, model, max_states: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_states = max_states
+    self.states = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.states) >= self.max_states:
+      self.states.popitem(last=False)
+
+    self.states[request_id] = ModelState(cache)
+
+  def __call__(self, x: Tensor, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.states:
+      self.init_cache(h, request_id)
+    else:
+      self.states.move_to_end(request_id)
+    out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
+    self.states[request_id].start += h.shape[1]
+    return out
+
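`StatefulModel` keeps at most `max_states` per-request KV caches in an `OrderedDict`, evicting the least recently used request when a new one arrives. A self-contained sketch of just that eviction policy (the model and cache tensors are stubbed out; names here are illustrative):

```python
from collections import OrderedDict

class LRUStates:
  def __init__(self, max_states: int = 2):
    self.max_states = max_states
    self.states = OrderedDict()

  def touch(self, request_id: str):
    if request_id not in self.states:
      if len(self.states) >= self.max_states:
        self.states.popitem(last=False)       # evict the least recently used request
      self.states[request_id] = {"start": 0}  # stand-in for ModelState(cache)
    else:
      self.states.move_to_end(request_id)     # mark as most recently used
    return self.states[request_id]

states = LRUStates(max_states=2)
states.touch("req-a"); states.touch("req-b")
states.touch("req-c")                         # "req-a" is evicted
assert list(states.states) == ["req-b", "req-c"]
```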

+ 6 - 1
exo/inference/tokenizers.py

@@ -7,14 +7,18 @@ from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
 
 
+
 class DummyTokenizer:
 class DummyTokenizer:
   def __init__(self):
   def __init__(self):
     self.eos_token_id = 0
     self.eos_token_id = 0
+
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
-    return [1,2,3]
+    return [1, 2, 3]
+
   def decode(self, tokens):
   def decode(self, tokens):
     return "dummy"
     return "dummy"
 
 
+
 async def resolve_tokenizer(model_id: str):
 async def resolve_tokenizer(model_id: str):
   if model_id == "dummy":
   if model_id == "dummy":
     return DummyTokenizer()
     return DummyTokenizer()
@@ -29,6 +33,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
   return await _resolve_tokenizer(model_id)
 
 
+
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")

+ 76 - 40
exo/main.py

@@ -1,8 +1,12 @@
 import argparse
 import argparse
 import asyncio
 import asyncio
+import atexit
 import signal
 import signal
 import json
 import json
 import logging
 import logging
+import platform
+import os
+import sys
 import time
 import time
 import traceback
 import traceback
 import uuid
 import uuid
@@ -17,22 +21,24 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWe
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
-from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.orchestration.node import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 
 
 # parse args
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
@@ -42,7 +48,7 @@ parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale",
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
@@ -54,14 +60,13 @@ parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name
 args = parser.parse_args()
 args = parser.parse_args()
 print(f"Selected inference engine: {args.inference_engine}")
 print(f"Selected inference engine: {args.inference_engine}")
 
 
-
 print_yellow_exo()
 print_yellow_exo()
 
 
-
 system_info = get_system_info()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 print(f"Detected system: {system_info}")
 
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
+                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
 
@@ -84,9 +89,23 @@ if DEBUG >= 0:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
 
 if args.discovery_module == "udp":
 if args.discovery_module == "udp":
-  discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+  discovery = UDPDiscovery(
+    args.node_id,
+    args.node_port,
+    args.listen_port,
+    args.broadcast_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout
+  )
 elif args.discovery_module == "tailscale":
 elif args.discovery_module == "tailscale":
-  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
+  discovery = TailscaleDiscovery(
+    args.node_id,
+    args.node_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout,
+    tailscale_api_key=args.tailscale_api_key,
+    tailnet=args.tailnet_name
+  )
 elif args.discovery_module == "manual":
 elif args.discovery_module == "manual":
   if not args.discovery_config_path:
   if not args.discovery_config_path:
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
     raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
@@ -108,22 +127,26 @@ api = ChatGPTAPI(
   node,
   node,
   inference_engine.__class__.__name__,
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
 )
 )
 node.on_token.register("update_topology_viz").on_next(
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
 )
+
 def preemptively_start_download(request_id: str, opaque_status: str):
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
   try:
     status = json.loads(opaque_status)
     status = json.loads(opaque_status)
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
   except Exception as e:
   except Exception as e:
     if DEBUG >= 2:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
       traceback.print_exc()
+
+
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
 
 if args.prometheus_client_port:
 if args.prometheus_client_port:
@@ -132,38 +155,24 @@ if args.prometheus_client_port:
 
 
 last_broadcast_time = 0
 last_broadcast_time = 0
 
 
-def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
-            "type": "download_progress",
-            "node_id": node.id,
-            "progress": event.to_dict()
-        })))
 
 
-shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 
 
 
 
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  inference_class = inference_engine.__class__.__name__
+  shard = build_base_shard(model_name, inference_class)
   if not shard:
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
     return
-  tokenizer = await resolve_tokenizer(shard.model_id)
+  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   request_id = str(uuid.uuid4())
   request_id = str(uuid.uuid4())
   callback_id = f"cli-wait-response-{request_id}"
   callback_id = f"cli-wait-response-{request_id}"
   callback = node.on_token.register(callback_id)
   callback = node.on_token.register(callback_id)
@@ -173,7 +182,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 
 
   try:
   try:
     print(f"Processing prompt: {prompt}")
     print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, None, request_id=request_id)
+    await node.process_prompt(shard, prompt, request_id=request_id)
 
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
 
 
@@ -189,12 +198,38 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 async def main():
 async def main():
   loop = asyncio.get_running_loop()
   loop = asyncio.get_running_loop()
 
 
+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+
+  if args.models_seed_dir is not None:
+    try:
+      await move_models_to_hf(args.models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
+  def restore_cursor():
+    if platform.system() != "Windows":
+        os.system("tput cnorm")  # Show cursor
+
+  # Restore the cursor when the program exits
+  atexit.register(restore_cursor)
+
   # Use a more direct approach to handle signals
   # Use a more direct approach to handle signals
   def handle_exit():
   def handle_exit():
-    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
 
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
 
   await node.start(wait_for_peers=args.wait_for_peers)
   await node.start(wait_for_peers=args.wait_for_peers)
 
 
@@ -217,8 +252,9 @@ def run():
   except KeyboardInterrupt:
   except KeyboardInterrupt:
     print("Received keyboard interrupt. Shutting down...")
     print("Received keyboard interrupt. Shutting down...")
   finally:
   finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
     loop.close()
     loop.close()
 
 
+
 if __name__ == "__main__":
 if __name__ == "__main__":
   run()
   run()
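A minimal standalone sketch of the throttling rule used by `throttled_broadcast` above: completion events are always forwarded, anything else at most once per 100 ms (function and variable names here are illustrative):

```python
import time

last_broadcast_time = 0.0

def should_broadcast(status: str, now: float, min_interval: float = 0.1) -> bool:
  global last_broadcast_time
  if status == "complete" or now - last_broadcast_time >= min_interval:
    last_broadcast_time = now
    return True
  return False

t0 = time.time()
print(should_broadcast("in_progress", t0))         # True  (first event)
print(should_broadcast("in_progress", t0 + 0.05))  # False (within the 100 ms window)
print(should_broadcast("complete", t0 + 0.06))     # True  (completions always go out)
```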

+ 125 - 50
exo/models.py

@@ -1,73 +1,148 @@
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
+from typing import Optional, List
 
 
-model_base_shards = {
+model_cards = {
   ### llama
   ### llama
   "llama-3.2-1b": {
   "llama-3.2-1b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "layers": 16,
+    "repo": {
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
   },
   },
   "llama-3.2-3b": {
   "llama-3.2-3b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
   },
   },
   "llama-3.1-8b": {
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+    },
   },
   },
   "llama-3.1-70b": {
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   },
   "llama-3.1-70b-bf16": {
   "llama-3.1-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   },
-  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+    },
   },
   },
   "llama-3-70b": {
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
-  },
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+    },
+  },
+  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
   ### mistral
   ### mistral
-  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
-  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
   ### deepseek
   ### deepseek
-  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
-  "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
   ### llava
   ### llava
-  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
   ### qwen
-  "qwen-2.5-coder-1.5b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-coder-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-math-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-14b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),
-  },
-  "qwen-2.5-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "qwen-2.5-math-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
   ### nemotron
   ### nemotron
-  "nemotron-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "nemotron-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
+  # gemma
+  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
   # dummy
   # dummy
-  "dummy": {
-    "DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),
-  },
+  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
+}
+
+pretty_name = {
+  "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.1-8b": "Llama 3.1 8B",
+  "llama-3.1-70b": "Llama 3.1 70B",
+  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+  "llama-3.1-405b": "Llama 3.1 405B",
+  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+  "gemma2-9b": "Gemma2 9B",
+  "gemma2-27b": "Gemma2 27B",
+  "nemotron-70b": "Nemotron 70B",
+  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+  "mistral-nemo": "Mistral Nemo",
+  "mistral-large": "Mistral Large",
+  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+  "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+  "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-72b": "Qwen 2.5 72B",
+  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "llama-3-8b": "Llama 3 8B",
+  "llama-3-70b": "Llama 3 70B",
 }
 }
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  repo = get_repo(model_id, inference_engine_classname)
+  n_layers = model_cards.get(model_id, {}).get("layers", 0)
+  if repo is None or n_layers < 1:
+    return None
+  return Shard(model_id, 0, 0, n_layers)
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]

+ 11 - 8
exo/networking/grpc/grpc_peer_handle.py

@@ -9,7 +9,7 @@ from . import node_service_pb2_grpc
 from ..peer_handle import PeerHandle
 from ..peer_handle import PeerHandle
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
-from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
 
 
 
 
@@ -32,7 +32,11 @@ class GRPCPeerHandle(PeerHandle):
 
 
   async def connect(self):
   async def connect(self):
     if self.channel is None:
     if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
+      self.channel = grpc.aio.insecure_channel(self.address, options=[
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ('grpc.max_receive_message_length', 32*1024*1024),
+        ('grpc.max_send_message_length', 32*1024*1024)
+      ])
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
     await self.channel.channel_ready()
 
 
@@ -63,10 +67,9 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
         traceback.print_exc()
       return False
       return False
 
 
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       prompt=prompt,
-      image_str=image_str,
       shard=node_service_pb2.Shard(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         model_id=shard.model_id,
         start_layer=shard.start_layer,
         start_layer=shard.start_layer,
@@ -74,7 +77,6 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
         n_layers=shard.n_layers,
       ),
       ),
       request_id=request_id,
       request_id=request_id,
-      inference_state=inference_state,
     )
     )
     response = await self.stub.SendPrompt(request)
     response = await self.stub.SendPrompt(request)
 
 
@@ -83,7 +85,7 @@ class GRPCPeerHandle(PeerHandle):
 
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         model_id=shard.model_id,
@@ -93,7 +95,6 @@ class GRPCPeerHandle(PeerHandle):
       ),
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
       request_id=request_id,
-      inference_state=inference_state,
     )
     )
     response = await self.stub.SendTensor(request)
     response = await self.stub.SendTensor(request)
 
 
@@ -117,7 +118,9 @@ class GRPCPeerHandle(PeerHandle):
     response = await self.stub.CollectTopology(request)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
     for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+      )
       topology.update_node(node_id, device_capabilities)
       topology.update_node(node_id, device_capabilities)
     for node_id, peers in response.peer_graph.items():
     for node_id, peers in response.peer_graph.items():
       for peer_id in peers.peer_ids:
       for peer_id in peers.peer_ids:
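`send_tensor` ships a numpy array as raw bytes plus a dtype string and shape; the round trip used above can be checked in isolation:

```python
import numpy as np

tensor = np.arange(6, dtype=np.float32).reshape(2, 3)
tensor_data, dtype, shape = tensor.tobytes(), str(tensor.dtype), tensor.shape

restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert (restored == tensor).all()
```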

+ 3 - 5
exo/networking/grpc/grpc_server.py

@@ -49,10 +49,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
       n_layers=request.shard.n_layers,
     )
     )
     prompt = request.prompt
     prompt = request.prompt
-    image_str = request.image_str
     request_id = request.request_id
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 
@@ -65,9 +64,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
     request_id = request.request_id
-    inference_state = request.inference_state
 
 
-    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    result = await self.node.process_tensor(shard, tensor, request_id)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()

+ 2 - 5
exo/networking/grpc/node_service.proto

@@ -22,16 +22,13 @@ message Shard {
 message PromptRequest {
 message PromptRequest {
   Shard shard = 1;
   Shard shard = 1;
   string prompt = 2;
   string prompt = 2;
-  optional string image_str = 3;
-  optional string request_id = 4;
-  optional string inference_state = 5;
+  optional string request_id = 3;
 }
 }
 
 
 message TensorRequest {
 message TensorRequest {
   Shard shard = 1;
   Shard shard = 1;
   Tensor tensor = 2;
   Tensor tensor = 2;
   optional string request_id = 3;
   optional string request_id = 3;
-  optional string inference_state = 4;
 }
 }
 
 
 message GetInferenceResultRequest {
 message GetInferenceResultRequest {
@@ -93,4 +90,4 @@ message HealthCheckResponse {
   bool is_healthy = 1;
   bool is_healthy = 1;
 }
 }
 
 
-message Empty {}
+message Empty {}
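With the compiled stubs (see `scripts/compile_grpc.sh` in the file list), the slimmed-down `PromptRequest` is built the same way `grpc_peer_handle.py` does above; a sketch, assuming the generated module is importable and with placeholder values:

```python
from exo.networking.grpc import node_service_pb2

request = node_service_pb2.PromptRequest(
  prompt="What is the meaning of exo?",
  shard=node_service_pb2.Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16),
  request_id="example-request-id",  # optional; image_str and inference_state no longer exist
)
print(request)
```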

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 9 - 7
exo/networking/manual/manual_discovery.py

@@ -19,7 +19,9 @@ class ManualDiscovery(Discovery):
     self.create_peer_handle = create_peer_handle
     self.create_peer_handle = create_peer_handle
 
 
     if node_id not in self.topology.peers:
     if node_id not in self.topology.peers:
-      raise ValueError(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}")
+      raise ValueError(
+        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
+      )
 
 
     self.listen_task = None
     self.listen_task = None
 
 
@@ -42,7 +44,6 @@ class ManualDiscovery(Discovery):
     if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
     if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
     return list(self.known_peers.values())
     return list(self.known_peers.values())
 
 
-
   async def task_find_peers_from_config(self):
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
     while True:
@@ -52,18 +53,19 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           peer = self.known_peers.get(peer_id)
           if not peer:
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)  
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           is_healthy = await peer.health_check()
           if is_healthy:
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             self.known_peers[peer_id] = peer
             self.known_peers[peer_id] = peer
           else:
           else:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try: del self.known_peers[peer_id]
-            except KeyError: pass
+            try:
+              del self.known_peers[peer_id]
+            except KeyError:
+              pass
         except Exception as e:
         except Exception as e:
-            if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
       await asyncio.sleep(1.0)
       await asyncio.sleep(1.0)
 
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
-

+ 0 - 1
exo/networking/manual/network_topology_config.py

@@ -17,7 +17,6 @@ class NetworkTopology(BaseModel):
   """
   """
   node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
   node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
   """
   """
-
   @classmethod
   @classmethod
   def from_path(cls, path: str) -> "NetworkTopology":
   def from_path(cls, path: str) -> "NetworkTopology":
     try:
     try:

+ 3 - 2
exo/networking/peer_handle.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
 
 
+
 class PeerHandle(ABC):
 class PeerHandle(ABC):
   @abstractmethod
   @abstractmethod
   def id(self) -> str:
   def id(self) -> str:
@@ -35,11 +36,11 @@ class PeerHandle(ABC):
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod

+ 11 - 11
exo/networking/tailscale/tailscale_discovery.py

@@ -8,6 +8,7 @@ from exo.topology.device_capabilities import DeviceCapabilities, device_capabili
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
 from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
 
 
+
 class TailscaleDiscovery(Discovery):
 class TailscaleDiscovery(Discovery):
   def __init__(
   def __init__(
     self,
     self,
@@ -69,14 +70,11 @@ class TailscaleDiscovery(Discovery):
         devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         current_time = time.time()
         current_time = time.time()
 
 
-        active_devices = {
-          name: device for name, device in devices.items()
-          if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
-        }
+        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
 
 
         if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
         if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
         if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
         if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
-        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time  - device.last_seen.timestamp()) for device in devices.values()])
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values()])
 
 
         for device in active_devices.values():
         for device in active_devices.values():
           if device.name == self.node_id: continue
           if device.name == self.node_id: continue
@@ -141,7 +139,13 @@ class TailscaleDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
           if should_remove: peers_to_remove.append(peer_id)
 
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}" for peer_handle, connected_at, last_seen in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            }
+          )
 
 
         for peer_id in peers_to_remove:
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
           if peer_id in self.known_peers:
@@ -164,9 +168,5 @@ class TailscaleDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
       return True
 
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove
     return should_remove
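The collapsed expression above encodes three removal conditions; a standalone sketch (names illustrative) to make them explicit:

```python
def should_remove_peer(is_connected: bool, health_ok: bool, connected_at: float,
                       last_seen: float, current_time: float, discovery_timeout: float = 30.0) -> bool:
  return ((not is_connected and current_time - connected_at > discovery_timeout)  # never connected in time
          or (current_time - last_seen > discovery_timeout)                       # went silent
          or (not health_ok))                                                     # failed health check

now = 100.0
print(should_remove_peer(True, True, connected_at=10.0, last_seen=95.0, current_time=now))  # False
print(should_remove_peer(True, True, connected_at=10.0, last_seen=50.0, current_time=now))  # True: silent > 30 s
```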

+ 15 - 26
exo/networking/tailscale/tailscale_helpers.py

@@ -7,6 +7,7 @@ from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from datetime import datetime, timezone
 from datetime import datetime, timezone
 
 
+
 class Device:
 class Device:
   def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
   def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
     self.device_id = device_id
     self.device_id = device_id
@@ -16,12 +17,7 @@ class Device:
 
 
   @classmethod
   @classmethod
   def from_dict(cls, data: Dict[str, Any]) -> 'Device':
   def from_dict(cls, data: Dict[str, Any]) -> 'Device':
-    return cls(
-      device_id=data.get('id', ''),
-      name=data.get('name', ''),
-      addresses=data.get('addresses', []),
-      last_seen=cls.parse_datetime(data.get('lastSeen'))
-    )
+    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
 
 
   @staticmethod
   @staticmethod
   def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
   def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
@@ -29,13 +25,10 @@ class Device:
       return None
       return None
     return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
     return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
 
 
+
 async def get_device_id() -> str:
 async def get_device_id() -> str:
   try:
   try:
-    process = await asyncio.create_subprocess_exec(
-      'tailscale', 'status', '--json',
-      stdout=asyncio.subprocess.PIPE,
-      stderr=asyncio.subprocess.PIPE
-    )
+    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
     stdout, stderr = await process.communicate()
     stdout, stderr = await process.communicate()
     if process.returncode != 0:
     if process.returncode != 0:
       raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
       raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
@@ -45,22 +38,16 @@ async def get_device_id() -> str:
   except Exception as e:
   except Exception as e:
     raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
     raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
 
 
+
 async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
 async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
   async with aiohttp.ClientSession() as session:
   async with aiohttp.ClientSession() as session:
     base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
     base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}',
-      'Content-Type': 'application/json'
-    }
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
 
 
     attributes = {
     attributes = {
-      "custom:exo_node_id": node_id.replace('-', '_'),
-      "custom:exo_node_port": node_port,
-      "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
-      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model),
-      "custom:exo_device_capability_memory": str(device_capabilities.memory),
-      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
-      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
       "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
       "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
     }
     }
 
 
@@ -73,12 +60,11 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
         else:
         else:
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
 
 
+
 async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
 async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
   async with aiohttp.ClientSession() as session:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}'
-    }
+    headers = {'Authorization': f'Bearer {api_key}'}
     async with session.get(url, headers=headers) as response:
     async with session.get(url, headers=headers) as response:
       if response.status == 200:
       if response.status == 200:
         data = await response.json()
         data = await response.json()
@@ -100,6 +86,7 @@ async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int,
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
         return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
         return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
 
 
+
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
   result = {}
   result = {}
   prefix = "custom:exo_"
   prefix = "custom:exo_"
@@ -112,12 +99,14 @@ def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
         result[attr_name] = float(value)
         result[attr_name] = float(value)
   return result
   return result
 
 
+
 def sanitize_attribute(value: str) -> str:
 def sanitize_attribute(value: str) -> str:
   # Replace invalid characters with underscores
   # Replace invalid characters with underscores
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   # Truncate to 50 characters
   # Truncate to 50 characters
   return sanitized_value[:50]
   return sanitized_value[:50]
 
 
+
 async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
 async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
   async with aiohttp.ClientSession() as session:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
     url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
@@ -133,4 +122,4 @@ async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]
         device = Device.from_dict(device_data)
         device = Device.from_dict(device_data)
         devices[device.name] = device
         devices[device.name] = device
 
 
-      return devices
+      return devices
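
A hedged usage sketch tying the helpers in this file together (get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices). The API key, tailnet name and capability numbers below are placeholders, and get_device_id requires the tailscale CLI to be installed.

import asyncio
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.networking.tailscale.tailscale_helpers import (
  get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices
)

async def main():
  api_key = "tskey-api-EXAMPLE"   # placeholder
  tailnet = "example.com"         # placeholder
  device_id = await get_device_id()
  caps = DeviceCapabilities(model="Example Device", chip="Example Chip", memory=16384,
                            flops=DeviceFlops(fp16=10.0, fp32=5.0, int8=20.0))
  # Publish this node's id, port and capabilities as Tailscale posture attributes ...
  await update_device_attributes(device_id, api_key, "example_node", 50051, caps)
  # ... then read them back and list every device in the tailnet.
  node_id, node_port, read_caps = await get_device_attributes(device_id, api_key)
  devices = await get_tailscale_devices(api_key, tailnet)
  print(node_id, node_port, read_caps, sorted(devices))

if __name__ == "__main__":
  asyncio.run(main())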

+ 2 - 0
exo/networking/tailscale/test_tailscale_discovery.py

@@ -5,6 +5,7 @@ from unittest import mock
 from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.peer_handle import PeerHandle
 from exo.networking.peer_handle import PeerHandle
 
 
+
 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
   async def asyncSetUp(self):
     self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
     self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
@@ -37,5 +38,6 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     # Check if discovered peers are instances of GRPCPeerHandle
     # Check if discovered peers are instances of GRPCPeerHandle
     print(peers)
     print(peers)
 
 
+
 if __name__ == '__main__':
 if __name__ == '__main__':
   unittest.main()
   unittest.main()

+ 14 - 15
exo/networking/udp/udp_discovery.py

@@ -9,6 +9,7 @@ from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
 
 
+
 class ListenProtocol(asyncio.DatagramProtocol):
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
     super().__init__()
@@ -90,17 +91,13 @@ class UDPDiscovery(Discovery):
           "node_id": self.node_id,
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1, # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
         })
         })
         if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
         if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
 
 
         transport = None
         transport = None
         try:
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-            lambda: BroadcastProtocol(message, self.broadcast_port),
-            local_addr=(addr, 0),
-            family=socket.AF_INET
-          )
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
           if DEBUG_DISCOVERY >= 3:
           if DEBUG_DISCOVERY >= 3:
             print(f"Broadcasting presence at ({addr})")
             print(f"Broadcasting presence at ({addr})")
         except Exception as e:
         except Exception as e:
@@ -145,7 +142,8 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers:
         if peer_id in self.known_peers:
           existing_peer_prio = self.known_peers[peer_id][3]
           existing_peer_prio = self.known_peers[peer_id][3]
           if existing_peer_prio >= peer_prio:
           if existing_peer_prio >= peer_prio:
-            if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
+            if DEBUG >= 1:
+              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
             return
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         if not await new_peer_handle.health_check():
         if not await new_peer_handle.health_check():
@@ -161,8 +159,7 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
 
 
   async def task_listen_for_peers(self):
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
-                                                            local_addr=("0.0.0.0", self.listen_port))
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2: print("Started listen task")
     if DEBUG_DISCOVERY >= 2: print("Started listen task")
 
 
   async def task_cleanup_peers(self):
   async def task_cleanup_peers(self):
@@ -177,7 +174,13 @@ class UDPDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
           if should_remove: peers_to_remove.append(peer_id)
 
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
+              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
+            }
+          )
 
 
         for peer_id in peers_to_remove:
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
           if peer_id in self.known_peers:
@@ -200,9 +203,5 @@ class UDPDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
       return True
 
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove
     return should_remove

+ 2 - 2
exo/orchestration/node.py

@@ -16,11 +16,11 @@ class Node(ABC):
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass
     pass
 
 
   @abstractmethod
   @abstractmethod

+ 112 - 122
exo/orchestration/standard_node.py

@@ -39,6 +39,8 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
@@ -87,24 +89,53 @@ class StandardNode(Node):
   def get_supported_inference_engines(self):
   def get_supported_inference_engines(self):
     supported_engine_names = []
     supported_engine_names = []
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-        supported_engine_names.append('mlx')
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
     else:
     else:
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('tinygrad')
     return supported_engine_names
     return supported_engine_names
 
 
   async def broadcast_supported_engines(self, supported_engines_names: List[str]):
   async def broadcast_supported_engines(self, supported_engines_names: List[str]):
-    status_message = json.dumps({
-        "type": "supported_inference_engines",
-        "node_id": self.id,
-        "engines": supported_engines_names
-    })
+    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
     await self.broadcast_opaque_status("", status_message)
     await self.broadcast_opaque_status("", status_message)
 
 
   def get_topology_inference_engines(self) -> List[List[str]]:
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
     return self.topology_inference_engines_pool
+  
+  async def process_inference_result(
+    self,
+    shard,
+    result: np.ndarray,
+    request_id: Optional[str] = None,
+  ):
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+    else:
+      forward = result
 
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+    else:
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+
+    return np.array(self.buffered_token_output[request_id][0])
+
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
     asyncio.create_task(
       self.broadcast_opaque_status(
       self.broadcast_opaque_status(
@@ -116,14 +147,12 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
           "prompt": prompt,
-          "image_str": image_str,
-          "inference_state": inference_state,
           "request_id": request_id,
           "request_id": request_id,
         }),
         }),
       )
       )
     )
     )
     start_time = time.perf_counter_ns()
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
     asyncio.create_task(
@@ -136,8 +165,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
           "prompt": prompt,
-          "image_str": image_str,
-          "inference_state": inference_state,
           "request_id": request_id,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
           "elapsed_time_ns": elapsed_time_ns,
           "result_size": resp.size if resp is not None else 0,
           "result_size": resp.size if resp is not None else 0,
@@ -146,42 +173,26 @@ class StandardNode(Node):
     )
     )
     return resp
     return resp
 
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
     if request_id is None:
       request_id = str(uuid.uuid4())
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
 
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
-    if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
-      return
-
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
-    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:
-      self.buffered_token_output[request_id][0].append(result.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-    if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if not shard.is_first_layer():
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      return None
+    else:
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      ret = await self.process_inference_result(shard, result, request_id) 
+      return result
 
 
   async def process_tensor(
   async def process_tensor(
     self,
     self,
     base_shard: Shard,
     base_shard: Shard,
     tensor: np.ndarray,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
     asyncio.create_task(
@@ -196,12 +207,11 @@ class StandardNode(Node):
           "tensor_size": tensor.size,
           "tensor_size": tensor.size,
           "tensor_shape": tensor.shape,
           "tensor_shape": tensor.shape,
           "request_id": request_id,
           "request_id": request_id,
-          "inference_state": inference_state,
         }),
         }),
       )
       )
     )
     )
     start_time = time.perf_counter_ns()
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    resp = await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
     asyncio.create_task(
@@ -226,84 +236,77 @@ class StandardNode(Node):
     base_shard: Shard,
     base_shard: Shard,
     tensor: np.ndarray,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
   ) -> Optional[np.ndarray]:
     if request_id is None:
     if request_id is None:
       request_id = str(uuid.uuid4())
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
     shard = self.get_current_shard(base_shard)
 
 
+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
     try:
-      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if is_finished:
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-      if result.size == 1:  # we got a new token out
-        self.buffered_token_output[request_id][0].append(result.item())
-        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-      if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+      ret = await self.process_inference_result(shard, result, request_id) 
+      return ret
     except Exception as e:
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       traceback.print_exc()
       return None
       return None
 
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     self,
     base_shard: Shard,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
     request_id: str,
-    image_str: Optional[str] = None,
-    inference_state: Optional[str] = None,
+    target_index: int,
   ) -> None:
   ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+
+  def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-    shard = self.get_current_shard(base_shard)
-
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-    if current_partition_index is not None:
-      next_partition_index = (current_partition_index+1) % len(partitions)
-      next_partition: Partition = partitions[next_partition_index]
-      next_shard = shards[next_partition_index]
-      if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-      if next_partition.node_id == self.id:
-        if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-        else:
-          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
-        return
-
-      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {next_partition} not found")
-
-      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-      if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-      else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+    return (current_partition_index + offset) % len(partitions)
 
 
-  def get_current_shard(self, base_shard: Shard) -> Shard:
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
     partitions = self.partitioning_strategy.partition(self.topology)
     partitions = self.partitioning_strategy.partition(self.topology)
     shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
     shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if current_partition_index is None:
-      raise ValueError(f"No current partition found for node: {self.id}")
-    return shards[current_partition_index]
+    return shards[index]
 
 
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
     next_peers = await self.discovery.discover_peers(wait_for_peers)
     next_peers = await self.discovery.discover_peers(wait_for_peers)
@@ -311,20 +314,16 @@ class StandardNode(Node):
     next_peer_ids = {peer.id() for peer in next_peers}
     next_peer_ids = {peer.id() for peer in next_peers}
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
-    peers_updated = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
-    peers_unchanged = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
+    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
 
 
     def _pretty(peers: List[PeerHandle]) -> List[str]:
     def _pretty(peers: List[PeerHandle]) -> List[str]:
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
-    if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    if DEBUG >= 2:
+      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
 
 
     async def disconnect_with_timeout(peer, timeout=5):
     async def disconnect_with_timeout(peer, timeout=5):
       try:
       try:
@@ -344,14 +343,8 @@ class StandardNode(Node):
         traceback.print_exc()
         traceback.print_exc()
         return False
         return False
 
 
-    disconnect_results = await asyncio.gather(
-      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
-      return_exceptions=True
-    )
-    connect_results = await asyncio.gather(
-      *(connect_with_timeout(peer) for peer in peers_to_connect),
-      return_exceptions=True
-    )
+    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
 
 
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
@@ -370,12 +363,7 @@ class StandardNode(Node):
     supported_engines = self.get_supported_inference_engines()
     supported_engines = self.get_supported_inference_engines()
     await self.broadcast_supported_engines(supported_engines)
     await self.broadcast_supported_engines(supported_engines)
     if len(self.get_topology_inference_engines()):
     if len(self.get_topology_inference_engines()):
-      if any(len(engines) == 1 and "tinygrad" in engines for engines in self.get_topology_inference_engines()):
-        if DEBUG >= 1: print("Found node with only tinygrad, using tinygrad on all nodes")
-        self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
-      else:
-        if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader) 
+      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
 
 
   async def periodic_topology_collection(self, interval: int):
   async def periodic_topology_collection(self, interval: int):
     while True:
     while True:
@@ -422,6 +410,7 @@ class StandardNode(Node):
         self.topology.merge(other_topology)
         self.topology.merge(other_topology)
       except Exception as e:
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
         print(f"Error collecting topology from {peer.id()}: {e}")
+        traceback.print_exc()
 
 
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     self.topology = next_topology
@@ -440,7 +429,7 @@ class StandardNode(Node):
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
     self.on_token.trigger_all(request_id, tokens, is_finished)
-
+  
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):
     async def send_result_to_peer(peer):
       try:
       try:
@@ -464,6 +453,7 @@ class StandardNode(Node):
       except Exception as e:
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
         traceback.print_exc()
+
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
     self.on_opaque_status.trigger_all(request_id, status)
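
The refactor above replaces forward_to_next_shard with forward_prompt/forward_tensor built on get_partition_index and get_current_shard. The routing itself is just ring arithmetic over the partition list; a minimal sketch under simplified stand-in types (Partition here carries only a node_id, whereas the real one also holds layer ranges):

from dataclasses import dataclass
from typing import List

@dataclass
class Partition:
  node_id: str

def get_partition_index(partitions: List[Partition], my_id: str, offset: int = 0) -> int:
  current = next((i for i, p in enumerate(partitions) if p.node_id == my_id), None)
  if current is None:
    raise ValueError(f"No current partition found for node: {my_id}")
  return (current + offset) % len(partitions)

partitions = [Partition("a"), Partition("b"), Partition("c")]
assert get_partition_index(partitions, "a") == 0            # this node's own slot
assert get_partition_index(partitions, "a", offset=1) == 1  # next shard in the ring
assert get_partition_index(partitions, "c", offset=1) == 0  # wraps back to the first node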

+ 129 - 62
exo/tinychat/index.css

@@ -1,31 +1,11 @@
 /* define colors */
 /* define colors */
 :root {
 :root {
-  --primary-color: #a52e4d;
-  --primary-color-transparent: #a52e4d66;
-  --secondary-color: #228039;
-  --secondary-color-transparent: #22803966;
-
+  --primary-color: #fff;
+  --secondary-color: #2a2a2a;
+  --secondary-color-transparent: #ffffff66;
+  --primary-bg-color: #1a1a1a;
+  --foreground-color: #f0f0f0;
   --red-color: #a52e4d;
   --red-color: #a52e4d;
-  --green-color: #228039;
-  --silver-color: #88808e;
-}
-@media(prefers-color-scheme: light) {
-  :root {
-    --primary-bg-color: #f0f0f0;
-    --secondary-bg-color: #eeeeee;
-    --tertiary-bg-color: #dddddd;
-    --foreground-color: #111111;
-    --accent-color: #000000;
-  }
-}
-@media(prefers-color-scheme: dark) {
-  :root {
-    --primary-bg-color: #111111;
-    --secondary-bg-color: #131313;
-    --tertiary-bg-color: #232323;
-    --foreground-color: #f0f0f0;
-    --accent-color: #aaaaaa;
-  }
 }
 }
 
 
 main {
 main {
@@ -81,7 +61,11 @@ main {
   top: 0;
   top: 0;
   position: absolute;
   position: absolute;
 
 
-  background: linear-gradient(180deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    180deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 }
 .histories-end {
 .histories-end {
   height: 3rem;
   height: 3rem;
@@ -91,7 +75,11 @@ main {
   bottom: 0;
   bottom: 0;
   position: absolute;
   position: absolute;
 
 
-  background: linear-gradient(0deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 }
 
 
 .history {
 .history {
@@ -99,7 +87,7 @@ main {
   width: 100%;
   width: 100%;
   max-width: 40rem;
   max-width: 40rem;
 
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   border-radius: 10px;
   border-radius: 10px;
   border-left: 2px solid var(--primary-color);
   border-left: 2px solid var(--primary-color);
 
 
@@ -109,7 +97,7 @@ main {
   opacity: var(--opacity, 1);
   opacity: var(--opacity, 1);
 }
 }
 .history:hover {
 .history:hover {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
 }
 }
 
 
 .history-delete-button {
 .history-delete-button {
@@ -120,14 +108,14 @@ main {
   margin: 0;
   margin: 0;
   outline: none;
   outline: none;
   border: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   border-radius: 0 0 0 10px;
   cursor: pointer;
   cursor: pointer;
   transition: 0.2s;
   transition: 0.2s;
 }
 }
 .history-delete-button:hover {
 .history-delete-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
   padding: 0.75rem;
 }
 }
 
 
@@ -135,6 +123,7 @@ main {
   overflow-y: auto;
   overflow-y: auto;
   height: 100%;
   height: 100%;
   width: 100%;
   width: 100%;
+  max-width: 1200px;
 
 
   display: flex;
   display: flex;
   flex-direction: column;
   flex-direction: column;
@@ -145,28 +134,25 @@ main {
 }
 }
 
 
 .message {
 .message {
-  width: 96%;
-  max-width: 80rem;
-
-  display: grid;
-
-  background-color: var(--secondary-bg-color);
+  max-width: 75%;
   padding: 0.5rem 1rem;
   padding: 0.5rem 1rem;
-  border-radius: 10px;
+  border-radius: 20px;
 }
 }
 .message-role-assistant {
 .message-role-assistant {
-  border-bottom: 2px solid var(--primary-color);
-  border-left: 2px solid var(--primary-color);
-  box-shadow: -10px 10px 20px 2px var(--primary-color-transparent);
+  background-color: var(--secondary-color);
+  margin-right: auto;
+  color: #fff;
 }
 }
 .message-role-user {
 .message-role-user {
-  border-bottom: 2px solid var(--secondary-color);
-  border-right: 2px solid var(--secondary-color);
-  box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
+  margin-left: auto;
+  background-color: var(--primary-color);
+  color: #000;
 }
 }
 .download-progress {
 .download-progress {
   margin-bottom: 12em;
   margin-bottom: 12em;
   overflow-y: auto;
   overflow-y: auto;
+  min-height: 350px;
+  padding: 2rem;
 }
 }
 .message > pre {
 .message > pre {
   white-space: pre-wrap;
   white-space: pre-wrap;
@@ -191,17 +177,70 @@ main {
 }
 }
 
 
 .toast {
 .toast {
-    width: 100%; /* Take up the full width of the page */
-    background-color: #fc2a2a; /* Dark background color */
-    color: #fff; /* White text color */
-    text-align: center; /* Centered text */
-    border-radius: 2px; /* Rounded borders */
-    padding: 16px; /* Padding */
-    position: fixed; /* Sit on top of the screen content */
-    z-index: 9999; /* Add a z-index if needed */
-    top: 0; /* Position at the top of the page */
-    left: 0; /* Extend from the left edge */
-    right: 0; /* Extend to the right edge */
+    width: 100%;
+    background-color: #fc2a2a;
+    color: #fff;
+    text-align: left;
+    border-radius: 2px;
+    padding: 16px;
+    position: fixed;
+    z-index: 9999;
+    top: 0;
+    left: 0;
+    right: 0;
+    display: flex;
+    flex-direction: column;
+    white-space: pre-wrap;
+    font-family: monospace;
+}
+
+.toast-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    width: 100%;
+}
+
+.toast-error-message {
+    flex-grow: 1;
+}
+
+.toast-header-buttons {
+    display: flex;
+    align-items: center;
+    gap: 16px;
+    margin-left: 24px;
+}
+
+.toast-expand-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1em;
+}
+
+.toast-close-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1.2em;
+    line-height: 1;
+}
+
+.toast-expand-button:hover,
+.toast-close-button:hover {
+    opacity: 0.8;
+}
+
+.toast-content {
+    margin-top: 10px;
+    padding: 10px;
+    background-color: rgba(0, 0, 0, 0.2);
+    border-radius: 4px;
 }
 }
 
 
 .hljs {
 .hljs {
@@ -220,14 +259,14 @@ main {
   margin: 0;
   margin: 0;
   outline: none;
   outline: none;
   border: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   border-radius: 0 0 0 10px;
   cursor: pointer;
   cursor: pointer;
   transition: 0.2s;
   transition: 0.2s;
 }
 }
 .clipboard-button:hover {
 .clipboard-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
   padding: 0.75rem;
 }
 }
 
 
@@ -236,9 +275,14 @@ main {
   bottom: 0;
   bottom: 0;
 
 
   /* linear gradient from background-color to transparent on the top */
   /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(0deg, var(--primary-bg-color) 55%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
 
 
   width: 100%;
   width: 100%;
+  max-width: 1200px;
   display: flex;
   display: flex;
   flex-direction: column;
   flex-direction: column;
   justify-content: center;
   justify-content: center;
@@ -285,7 +329,7 @@ main {
   min-height: 3rem;
   min-height: 3rem;
   max-height: 8rem;
   max-height: 8rem;
 
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   color: var(--foreground-color);
   border-radius: 10px;
   border-radius: 10px;
   border: none;
   border: none;
@@ -297,8 +341,8 @@ main {
   height: 3rem;
   height: 3rem;
   width: 4rem;
   width: 4rem;
 
 
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
+  background-color: var(--primary-color);
+  color: var(--secondary-color);
   border-radius: 10px;
   border-radius: 10px;
   padding: 0.5rem;
   padding: 0.5rem;
   cursor: pointer;
   cursor: pointer;
@@ -307,7 +351,7 @@ main {
   background-color: var(--secondary-color-transparent);
   background-color: var(--secondary-color-transparent);
 }
 }
 .input-button:disabled {
 .input-button:disabled {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   cursor: not-allowed;
   cursor: not-allowed;
 }
 }
 
 
@@ -414,4 +458,27 @@ p {
   max-width: 100%;
   max-width: 100%;
   max-height: 100%;
   max-height: 100%;
   object-fit: contain;
   object-fit: contain;
+}
+
+.clear-history-button {
+  background-color: var(--red-color);
+  color: white;
+  padding: 10px 20px;
+  border-radius: 5px;
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  transition: all 0.3s ease;
+  margin: 1rem auto;
+  border: none;
+  cursor: pointer;
+}
+
+.clear-history-button:hover {
+  opacity: 0.8;
+  transform: scale(1.05);
+}
+
+.clear-history-button i {
+  font-size: 14px;
 }
 }

+ 26 - 25
exo/tinychat/index.html

@@ -26,33 +26,27 @@
 <body>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
 <main x-data="state" x-init="console.log(endpoint)">
      <!-- Error Toast -->
      <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity x-text="errorMessage" class="toast">
+    <div x-show="errorMessage" x-transition.opacity class="toast">
+        <div class="toast-header">
+            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <div class="toast-header-buttons">
+                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
+                        class="toast-expand-button" 
+                        x-show="errorMessage.stack">
+                    <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
+                </button>
+                <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
+                    <i class="fas fa-times"></i>
+                </button>
+            </div>
+        </div>
+        <div class="toast-content" x-show="errorExpanded" x-transition>
+            <span x-text="errorMessage.stack"></span>
+        </div>
     </div>
     </div>
 <div class="model-selector">
 <div class="model-selector">
-<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel">
-<option selected="" value="llama-3.2-1b">Llama 3.2 1B</option>
-<option value="llama-3.2-3b">Llama 3.2 3B</option>
-<option value="llama-3.1-8b">Llama 3.1 8B</option>
-<option value="llama-3.1-70b">Llama 3.1 70B</option>
-<option value="llama-3.1-70b-bf16">Llama 3.1 70B (BF16)</option>
-<option value="llama-3.1-405b">Llama 3.1 405B</option>
-<option value="llama-3-8b">Llama 3 8B</option>
-<option value="llama-3-70b">Llama 3 70B</option>
-<option value="nemotron-70b">Nemotron 70B</option>
-<option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
-<option value="mistral-nemo">Mistral Nemo</option>
-<option value="mistral-large">Mistral Large</option>
-<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
-<option value="deepseek-coder-v2.5">Deepseek Coder V2.5</option>
-<option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
-<option value="qwen-2.5-coder-1.5b">Qwen 2.5 Coder 1.5B</option>
-<option value="qwen-2.5-coder-7b">Qwen 2.5 Coder 7B</option>
-<option value="qwen-2.5-7b">Qwen 2.5 7B</option>
-<option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
-<option value="qwen-2.5-14b">Qwen 2.5 14B</option>
-<option value="qwen-2.5-72b">Qwen 2.5 72B</option>
-<option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
-</select>
+  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+  </select>
 </div>
 </div>
 <div @popstate.window="
 <div @popstate.window="
       if (home === 2) {
       if (home === 2) {
@@ -68,6 +62,13 @@
      if (home === -1) setTimeout(() => home = 0, 100);
      if (home === -1) setTimeout(() => home = 0, 100);
     " x-show="home === 0" x-transition="">
     " x-show="home === 0" x-transition="">
 <h1 class="title megrim-regular">tinychat</h1>
 <h1 class="title megrim-regular">tinychat</h1>
+<template x-if="histories.length">
+  <button 
+    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
+    class="clear-history-button">
+    <i class="fas fa-trash"></i> Clear All History
+  </button>
+</template>
 <div class="histories-container-container">
 <div class="histories-container-container">
 <template x-if="histories.length">
 <template x-if="histories.length">
 <div class="histories-start"></div>
 <div class="histories-start"></div>

+ 109 - 14
exo/tinychat/index.js

@@ -4,8 +4,8 @@ document.addEventListener("alpine:init", () => {
     cstate: {
     cstate: {
       time: null,
       time: null,
       messages: [],
       messages: [],
-      selectedModel: 'llama-3.1-8b',
-    },
+      selectedModel: 'llama-3.2-1b',
+    },    
 
 
     // historical state
     // historical state
     histories: JSON.parse(localStorage.getItem("histories")) || [],
     histories: JSON.parse(localStorage.getItem("histories")) || [],
@@ -14,6 +14,8 @@ document.addEventListener("alpine:init", () => {
     generating: false,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
     endpoint: `${window.location.origin}/v1`,
     errorMessage: null,
     errorMessage: null,
+    errorExpanded: false,
+    errorTimeout: null,
 
 
     // performance tracking
     // performance tracking
     time_till_first: 0,
     time_till_first: 0,
@@ -47,6 +49,12 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("histories", JSON.stringify(this.histories));
         localStorage.setItem("histories", JSON.stringify(this.histories));
       }
       }
     },
     },
+
+    clearAllHistory() {
+      this.histories = [];
+      localStorage.setItem("histories", JSON.stringify([]));
+    },
+
     // Utility functions
     // Utility functions
     formatBytes(bytes) {
     formatBytes(bytes) {
       if (bytes === 0) return '0 B';
       if (bytes === 0) return '0 B';
@@ -66,6 +74,56 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
       return `${s}s`;
     },
     },
 
 
+    async populateSelector() {
+      try {
+        const response = await fetch(`${window.location.origin}/modelpool`);
+        const responseText = await response.text(); // Get raw response text first
+        
+        if (!response.ok) {
+          throw new Error(`HTTP error! status: ${response.status}`);
+        }
+        
+        // Try to parse the response text
+        let responseJson;
+        try {
+          responseJson = JSON.parse(responseText);
+        } catch (parseError) {
+          console.error('Failed to parse JSON:', parseError);
+          throw new Error(`Invalid JSON response: ${responseText}`);
+        }
+
+        const sel = document.querySelector(".model-select");
+        if (!sel) {
+          throw new Error("Could not find model selector element");
+        }
+
+        // Clear the current options and add new ones
+        sel.innerHTML = '';
+          
+        const modelDict = responseJson["model pool"];
+        if (!modelDict) {
+          throw new Error("Response missing 'model pool' property");
+        }
+
+        Object.entries(modelDict).forEach(([key, value]) => {
+          const opt = document.createElement("option");
+          opt.value = key;
+          opt.textContent = value;
+          sel.appendChild(opt);
+        });
+
+        // Set initial value to the first model
+        const firstKey = Object.keys(modelDict)[0];
+        if (firstKey) {
+          sel.value = firstKey;
+          this.cstate.selectedModel = firstKey;
+        }
+      } catch (error) {
+        console.error("Error populating model selector:", error);
+        this.errorMessage = `Failed to load models: ${error.message}`;
+      }
+    },
+
     async handleImageUpload(event) {
     async handleImageUpload(event) {
       const file = event.target.files[0];
       const file = event.target.files[0];
       if (file) {
       if (file) {
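
The populateSelector() added above expects GET /modelpool to return a JSON object whose "model pool" key maps model ids to display names; the keys become the <option> values and the values become the visible labels. A minimal sketch of a response that would satisfy it, in Python (the display names here are illustrative, not taken from this diff):

  import json

  # Shape consumed by populateSelector(): {"model pool": {model_id: display_name}}
  payload = {"model pool": {"llama-3.2-1b": "Llama 3.2 1B", "llama-3.2-3b": "Llama 3.2 3B"}}
  print(json.dumps(payload))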
@@ -110,12 +168,30 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("pendingMessage", value);
         localStorage.setItem("pendingMessage", value);
         this.processMessage(value);
         this.processMessage(value);
       } catch (error) {
       } catch (error) {
-        console.error('error', error)
-        this.lastErrorMessage = error.message || 'Unknown error on handleSend';
-        this.errorMessage = error.message || 'Unknown error on handleSend';
-        setTimeout(() => {
-          this.errorMessage = null;
-        }, 5 * 1000)
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
         this.generating = false;
       }
     },
@@ -232,12 +308,30 @@ document.addEventListener("alpine:init", () => {
           console.error("Failed to save histories to localStorage:", error);
           console.error("Failed to save histories to localStorage:", error);
         }
         }
       } catch (error) {
       } catch (error) {
-        console.error('error', error)
-        this.lastErrorMessage = error;
-        this.errorMessage = error;
-        setTimeout(() => {
-          this.errorMessage = null;
-        }, 5 * 1000)
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
       } finally {
         this.generating = false;
       }
@@ -529,6 +623,7 @@ function createParser(onParse) {
     }
   }
 }
+
 const BOM = [239, 187, 191];
 function hasBom(buffer) {
   return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);

+ 44 - 41
exo/tinychat/update_deps.py

@@ -4,49 +4,52 @@ from bs4 import BeautifulSoup
 from urllib.parse import urljoin, urlparse
 import re
 
+
 def download_file(url, local_path):
-    response = requests.get(url)
-    if response.status_code == 200:
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, 'wb') as f:
-            f.write(response.content)
-        print(f"Downloaded: {local_path}")
-    else:
-        print(response.status_code)
-        print(f"Failed to download: {url}")
+  response = requests.get(url)
+  if response.status_code == 200:
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+      f.write(response.content)
+    print(f"Downloaded: {local_path}")
+  else:
+    print(response.status_code)
+    print(f"Failed to download: {url}")
+
 
 def update_html(html_content, base_url):
-    soup = BeautifulSoup(html_content, 'html.parser')
+  soup = BeautifulSoup(html_content, 'html.parser')
 
-    for tag in soup.find_all(['script', 'link']):
-        if tag.has_attr('src'):
-            url = tag['src']
-        elif tag.has_attr('href'):
-            url = tag['href']
-        else:
-            continue
+  for tag in soup.find_all(['script', 'link']):
+    if tag.has_attr('src'):
+      url = tag['src']
+    elif tag.has_attr('href'):
+      url = tag['href']
+    else:
+      continue
+
+    if url.startswith(('http://', 'https://')):
+      full_url = url
+    else:
+      full_url = urljoin(base_url, url)
 
-        if url.startswith(('http://', 'https://')):
-            full_url = url
-        else:
-            full_url = urljoin(base_url, url)
+    parsed_url = urlparse(full_url)
+    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
 
-        parsed_url = urlparse(full_url)
-        local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+    download_file(full_url, local_path)
 
-        download_file(full_url, local_path)
+    relative_path = os.path.relpath(local_path, '.')
+    if tag.name == 'script':
+      tag['src'] = "/" + relative_path
+    elif tag.name == 'link':
+      tag['href'] = "/" + relative_path
 
 
-        relative_path = os.path.relpath(local_path, '.')
-        if tag.name == 'script':
-            tag['src'] = "/" + relative_path
-        elif tag.name == 'link':
-            tag['href'] = "/" + relative_path
+  return str(soup)
 
-    return str(soup)
 
 # Read the HTML file
 with open('./index.html', 'r') as f:
-    html_content = f.read()
+  html_content = f.read()
 
 # Update HTML and download files
 # updated_html = update_html(html_content, 'https://example.com')
@@ -68,7 +71,7 @@ download_file(css_url, css_output_path)
 
 # Parse CSS file for font URLs
 with open(css_output_path, 'r', encoding='utf-8') as f:
-    css_content = f.read()
+  css_content = f.read()
 
 # Extract font URLs from the CSS content
 font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
@@ -77,14 +80,14 @@ print(f"Found {len(font_urls)} font URLs")
 
 
 # Download font files
 # Download font files
 for font_url in font_urls:
 for font_url in font_urls:
-    font_url = font_url.strip('"\'')
-    if font_url.startswith('../'):
-        font_url = font_url[3:]
+  font_url = font_url.strip('"\'')
+  if font_url.startswith('../'):
+    font_url = font_url[3:]
 
-    # Use base_url instead of urljoin to keep the version number
-    full_url = base_url + font_url
-    relative_path = font_url
-    output_path = os.path.join(output_dir, relative_path)
-    download_file(full_url, output_path)
+  # Use base_url instead of urljoin to keep the version number
+  full_url = base_url + font_url
+  relative_path = font_url
+  output_path = os.path.join(output_dir, relative_path)
+  download_file(full_url, output_path)
 
-print("Download complete!")
+print("Download complete!")

+ 4 - 2
exo/topology/device_capabilities.py

@@ -52,9 +52,11 @@ CHIP_FLOPS = {
   "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
   "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
   "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
   "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
   "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
   "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
   "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
   "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-  "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=4.26*TFLOPS, fp16=8.52*TFLOPS, int8=17.04*TFLOPS),
+  "Apple M4 Pro": DeviceFlops(fp32=5.72*TFLOPS, fp16=11.44*TFLOPS, int8=22.88*TFLOPS),
+  "Apple M4 Max": DeviceFlops(fp32=18.03*TFLOPS, fp16=36.07*TFLOPS, int8=72.14*TFLOPS),
   ### A chips
   ### A chips
   "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
   "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
   "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
   "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),

+ 1 - 1
extra/start_openwebui.sh

@@ -1,3 +1,3 @@
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main

+ 1 - 1
format.py

@@ -21,7 +21,7 @@ def run_yapf(target):
 
 def main():
   if len(sys.argv) < 2:
-    print("Usage: python format.py <directory_or_file>")
+    print("Usage: python3 format.py <directory_or_file> e.g. python3 format.py ./exo")
     sys.exit(1)
 
   target = sys.argv[1]

+ 0 - 5
lint.sh

@@ -1,5 +0,0 @@
-#!/bin/bash
-
-pip3 install -e '.[linting]'
-python3 -m ruff check .
-python3 -m pylint .

+ 0 - 7
pyproject.toml

@@ -1,7 +0,0 @@
-[tool.pylint.format]
-indent-string = '  '
-max-line-length = 200
-
-[tool.autopep8]
-max_line_length = 200
-indent_size = 2

+ 0 - 43
ruff.toml

@@ -1,43 +0,0 @@
-indent-width = 2
-preview = true
-target-version = "py312"
-
-lint.select = [
-  "F",  # Pyflakes
-  "W6",
-  "E71",
-  "E72",
-  "E112",   # no-indented-block
-  "E113",   # unexpected-indentation
-  # "E124",
-  "E203",   # whitespace-before-punctuation
-  "E272",   # multiple-spaces-before-keyword
-  "E303",   # too-many-blank-lines
-  "E304",   # blank-line-after-decorator
-  "E501",   # line-too-long
-  # "E502",
-  "E702",   # multiple-statements-on-one-line-semicolon
-  "E703",   # useless-semicolon
-  "E731",   # lambda-assignment
-  "W191",   # tab-indentation
-  "W291",   # trailing-whitespace
-  "W293",   # blank-line-with-whitespace
-  "UP039",  # unnecessary-class-parentheses
-  "C416",   # unnecessary-comprehension
-  "RET506", # superfluous-else-raise
-  "RET507", # superfluous-else-continue
-  "A",      # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
-  "SIM105", # suppressible-exception
-  "FURB110",# if-exp-instead-of-or-operator
-]
-
-line-length = 200
-
-exclude = [
-  "docs/",
-  "examples/",
-  "extra/",
-  "exo/networking/grpc/node_service_pb2.py",
-  "exo/networking/grpc/node_service_pb2_grpc.py",
-  "exo/helpers.py",
-]

+ 60 - 0
scripts/build_exo.py

@@ -0,0 +1,60 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--onefile",
+        "--python-flag=no_site"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--macos-signed-app-name=com.exolabs.exo",
+            "--macos-sign-identity=auto",
+            "--macos-sign-notarization",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()
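
On macOS the build script generates one --include-module flag per file under exo/inference/mlx/models so Nuitka bundles models that are only imported dynamically. The discovery step in isolation (run from the repo root; it simply lists whatever modules exist there):

  import pkgutil

  # Mirrors the loop in scripts/build_exo.py: each discovered module becomes a Nuitka flag.
  for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models']):
    print(f"--include-module=exo.inference.mlx.models.{name}")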

+ 7 - 0
scripts/compile_grpc.sh

@@ -0,0 +1,7 @@
+#!/bin/bash
+source ./install.sh
+pushd exo/networking/grpc
+python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
+sed -i "s/import\ node_service_pb2/from . &/" node_service_pb2_grpc.py
+popd
+
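
The sed line rewrites the absolute import emitted by grpc_tools into a package-relative one (sed's & inserts the matched text). The same substitution in Python, applied to a typical generated import line used here only as an example:

  import re

  line = "import node_service_pb2 as node__service__pb2"  # example of a generated import
  print(re.sub(r"import node_service_pb2", r"from . import node_service_pb2", line))
  # -> from . import node_service_pb2 as node__service__pb2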

+ 10 - 13
setup.py

@@ -5,40 +5,37 @@ from setuptools import find_packages, setup
 
 # Base requirements for all platforms
 install_requires = [
-  "aiohttp==3.10.2",
+  "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
+  "nuitka==2.4.10",
   "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
-  "protobuf==5.27.1",
+  "protobuf==5.28.1",
   "psutil==6.0.0",
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
 ]
 
 extras_require = {
-  "linting": [
-    "pylint==3.2.6",
-    "ruff==0.5.5",
-    "mypy==1.11.0",
+  "formatting": [
     "yapf==0.40.2",
   ],
   "apple_silicon": [
-    "mlx==0.18.0",
-    "mlx-lm==0.18.2",
+    "mlx==0.20.0",
+    "mlx-lm==0.19.3",
   ],
 }
 

+ 1 - 1
test/reconnect.sh

@@ -1,7 +1,7 @@
 #!/bin/bash
 
 echo "Starting node 1"
-DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
 PID1=$!
 echo "Started node 1 PID: $PID1"
 echo "Starting node 2"

+ 121 - 0
test/test_model_helpers.py

@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()
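
The cases above call get_supported_models() with both short engine names and their mapped class names. The same call can be made directly; a minimal sketch using engine names taken from the test cases (output is a list of model ids such as "llama-3.2-1b", as exercised by the tests):

  from exo.models import get_supported_models

  # Models runnable when one node offers mlx or tinygrad and another offers only mlx.
  print(get_supported_models([["mlx", "tinygrad"], ["mlx"]]))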

+ 8 - 3
test/test_tokenizers.py

@@ -1,7 +1,7 @@
 import os
 import re
 from transformers import AutoTokenizer, AutoProcessor
-from exo.models import model_base_shards
+from exo.models import model_cards
 
 
 def test_tokenizer(name, tokenizer, verbose=False):
@@ -24,9 +24,14 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
-models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if not ignore_pattern.match(shard.model_id)]
+models = []
+for model_id in model_cards:
+  for engine_type, repo_id in model_cards[model_id].get("repo", {}).items():
+    if not ignore_pattern.match(repo_id):
+      models.append(repo_id)
+models = list(set(models))
 
 verbose = os.environ.get("VERBOSE", "0").lower() == "1"
 for m in models:
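
The ignore list is compiled into a single anchored regex with "*" widened to ".*"; a trimmed-down check of that behaviour (the two repo ids below are illustrative):

  import re

  ignore = ["mlx-community/Qwen*", "dummy"]  # trimmed copy of the list above
  ignore_pattern = re.compile(r"^(" + "|".join(m.replace("*", ".*") for m in ignore) + r")")
  print(bool(ignore_pattern.match("mlx-community/Qwen2.5-7B-Instruct-4bit")))   # True, skipped
  print(bool(ignore_pattern.match("mlx-community/Llama-3.2-1B-Instruct-4bit"))) # False, kept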

Some files were not shown because too many files changed in this diff