Browse source

Merge branch 'main' into pipsize

Alex Cheema, 6 months ago
parent
commit
29e5363ebe
75 changed files with 2318 additions and 1439 deletions
  1. .circleci/config.yml (+78 -38)
  2. .gitignore (+3 -2)
  3. .pylintrc (+0 -472)
  4. README.md (+26 -11)
  5. configure_mlx.sh (+18 -2)
  6. docs/exo-rounded.png (BIN)
  7. docs/exo-screenshot.png (BIN)
  8. examples/astra/astra/ContentView.swift (+1 -1)
  9. examples/chatgpt_api.sh (+1 -1)
  10. exo/__init__.py (+1 -1)
  11. exo/api/chatgpt_api.py (+76 -56)
  12. exo/download/hf/hf_helpers.py (+39 -8)
  13. exo/download/hf/hf_shard_download.py (+8 -6)
  14. exo/download/shard_download.py (+11 -1)
  15. exo/helpers.py (+20 -1)
  16. exo/inference/debug_inference_engine.py (+11 -12)
  17. exo/inference/dummy_inference_engine.py (+43 -0)
  18. exo/inference/inference_engine.py (+28 -4)
  19. exo/inference/mlx/models/base.py (+1 -1)
  20. exo/inference/mlx/models/deepseek_v2.py (+1 -1)
  21. exo/inference/mlx/models/gemma2.py (+118 -0)
  22. exo/inference/mlx/models/llama.py (+4 -4)
  23. exo/inference/mlx/models/qwen2.py (+2 -1)
  24. exo/inference/mlx/sharded_inference_engine.py (+47 -20)
  25. exo/inference/mlx/sharded_model.py (+0 -86)
  26. exo/inference/mlx/sharded_utils.py (+14 -8)
  27. exo/inference/mlx/stateful_model.py (+42 -0)
  28. exo/inference/mlx/test_sharded_llama.py (+4 -4)
  29. exo/inference/mlx/test_sharded_llava.py (+2 -2)
  30. exo/inference/test_dummy_inference_engine.py (+53 -0)
  31. exo/inference/test_inference_engine.py (+20 -28)
  32. exo/inference/tinygrad/inference.py (+34 -36)
  33. exo/inference/tinygrad/models/llama.py (+61 -36)
  34. exo/inference/tinygrad/stateful_model.py (+42 -0)
  35. exo/inference/tokenizers.py (+15 -0)
  36. exo/main.py (+98 -47)
  37. exo/models.py (+126 -40)
  38. exo/networking/grpc/grpc_peer_handle.py (+12 -9)
  39. exo/networking/grpc/grpc_server.py (+3 -5)
  40. exo/networking/grpc/node_service.proto (+2 -5)
  41. exo/networking/grpc/node_service_pb2.py (+0 -0)
  42. exo/networking/manual/__init__.py (+0 -0)
  43. exo/networking/manual/manual_discovery.py (+71 -0)
  44. exo/networking/manual/network_topology_config.py (+31 -0)
  45. exo/networking/manual/test_data/invalid_config.json (+17 -0)
  46. exo/networking/manual/test_data/invalid_json.json (+0 -0)
  47. exo/networking/manual/test_data/test_config.json (+32 -0)
  48. exo/networking/manual/test_data/test_config_single_node.json (+18 -0)
  49. exo/networking/manual/test_manual_discovery.py (+103 -0)
  50. exo/networking/manual/test_network_topology_config.py (+49 -0)
  51. exo/networking/peer_handle.py (+3 -2)
  52. exo/networking/tailscale/tailscale_discovery.py (+42 -18)
  53. exo/networking/tailscale/tailscale_helpers.py (+49 -20)
  54. exo/networking/tailscale/test_tailscale_discovery.py (+2 -0)
  55. exo/networking/udp/test_udp_discovery.py (+1 -0)
  56. exo/networking/udp/udp_discovery.py (+54 -29)
  57. exo/orchestration/node.py (+2 -2)
  58. exo/orchestration/standard_node.py (+139 -110)
  59. exo/tinychat/index.css (+150 -64)
  60. exo/tinychat/index.html (+51 -32)
  61. exo/tinychat/index.js (+149 -74)
  62. exo/tinychat/update_deps.py (+44 -41)
  63. exo/topology/device_capabilities.py (+29 -16)
  64. exo/viz/topology_viz.py (+1 -1)
  65. extra/start_openwebui.sh (+1 -1)
  66. format.py (+1 -1)
  67. lint.sh (+0 -5)
  68. pyproject.toml (+0 -7)
  69. ruff.toml (+0 -43)
  70. scripts/build_exo.py (+58 -0)
  71. scripts/compile_grpc.sh (+7 -0)
  72. setup.py (+19 -20)
  73. test/reconnect.sh (+1 -1)
  74. test/test_model_helpers.py (+121 -0)
  75. test/test_tokenizers.py (+8 -3)

+ 78 - 38
.circleci/config.yml

@@ -10,18 +10,28 @@ commands:
        type: string
      model_id:
        type: string
+      expected_output:
+        type: string
+      prompt:
+        type: string
    steps:
      - run:
          name: Run chatgpt api integration test (<<parameters.inference_engine>>, <<parameters.model_id>>)
          command: |
            source env/bin/activate

+            # Set CLANG=1 for tinygrad only
+            if [ "<<parameters.inference_engine>>" = "tinygrad" ]; then
+              pip install llvmlite
+              export TOKENIZERS_PARALLELISM=true SUPPORT_BF16=0 CLANG=1
+            fi
+
            # Start first instance
-            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 2>&1 | tee output1.log &
+            HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output1.log &
            PID1=$!

            # Start second instance
-            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 2>&1 | tee output2.log &
+            HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 --disable-tui 2>&1 | tee output2.log &
            PID2=$!

            # Wait for discovery
@@ -49,7 +59,7 @@ commands:
              -H "Content-Type: application/json" \
              -d '{
                "model": "<<parameters.model_id>>",
-                "messages": [{"role": "user", "content": "Keep responses concise. Who was the king of pop?"}],
+                "messages": [{"role": "user", "content": "<<parameters.prompt>>"}],
                "temperature": 0.7
              }')
            echo "Response 1: $response_1"
@@ -62,7 +72,7 @@ commands:
              -H "Content-Type: application/json" \
              -d '{
                "model": "<<parameters.model_id>>",
-                "messages": [{"role": "user", "content": "Keep responses concise. Who was the king of pop?"}],
+                "messages": [{"role": "user", "content": "<<parameters.prompt>>"}],
                "temperature": 0.7
              }')
            echo "Response 2: $response_2"
@@ -74,8 +84,8 @@ commands:
            kill $PID1 $PID2

            echo ""
-            if ! echo "$response_1" | grep -q "Michael Jackson" || ! echo "$response_2" | grep -q "Michael Jackson"; then
-              echo "Test failed: Response does not contain 'Michael Jackson'"
+            if ! echo "$response_1" | grep -q "<<parameters.expected_output>>" || ! echo "$response_2" | grep -q "<<parameters.expected_output>>"; then
+              echo "Test failed: Response does not contain '<<parameters.expected_output>>'"
              echo "Response 1: $response_1"
              echo ""
              echo "Response 2: $response_2"
@@ -85,14 +95,14 @@ commands:
              cat output2.log
              exit 1
            else
-              echo "Test passed: Response from both nodes contains 'Michael Jackson'"
+              echo "Test passed: Response from both nodes contains '<<parameters.expected_output>>'"
            fi

jobs:
  unit_test:
    macos:
      xcode: "16.0.0"
-    resource_class: macos.m1.large.gen1
+    resource_class: m2pro.large
    steps:
      - checkout
      - run:
@@ -116,10 +126,11 @@ jobs:
            METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
            echo "Running tokenizer tests..."
            python3 ./test/test_tokenizers.py
+            python3 ./test/test_model_helpers.py

  discovery_integration_test:
    macos:
-      xcode: "15.4.0"
+      xcode: "16.0.0"
    steps:
      - checkout
      - run:
@@ -138,9 +149,9 @@ jobs:
          name: Run discovery integration test
          command: |
            source env/bin/activate
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 > output1.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --disable-tui > output1.log 2>&1 &
            PID1=$!
-            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 > output2.log 2>&1 &
+            DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --disable-tui > output2.log 2>&1 &
            PID2=$!
            sleep 10
            kill $PID1 $PID2
@@ -158,8 +169,8 @@ jobs:

  chatgpt_api_integration_test_mlx:
    macos:
-      xcode: "15.4.0"
-    resource_class: macos.m1.large.gen1
+      xcode: "16.0.0"
+    resource_class: m2pro.large
    steps:
      - checkout
      - run:
@@ -176,37 +187,65 @@ jobs:
            pip install .
      - run_chatgpt_api_test:
          inference_engine: mlx
-          model_id: llama-3.1-8b
+          model_id: llama-3.2-1b
+          prompt: "Keep responses concise. Who was the king of pop?"
+          expected_output: "Michael Jackson"
+
+  chatgpt_api_integration_test_dummy:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: dummy
+          model_id: dummy-model
+          prompt: "Dummy prompt."
+          expected_output: "dummy"

  test_macos_m1:
    macos:
-      xcode: "15.4.0"
-    resource_class: macos.m1.large.gen1
+      xcode: "16.0.0"
+    resource_class: m2pro.large
    steps:
      - checkout
      - run: system_profiler SPHardwareDataType

-  # chatgpt_api_integration_test_tinygrad:
-  #   macos:
-  #     xcode: "15.4.0"
-  #   resource_class: macos.m1.large.gen1
-  #   steps:
-  #     - checkout
-  #     - run:
-  #         name: Set up Python
-  #         command: |
-  #           brew install python@3.12
-  #           python3.12 -m venv env
-  #           source env/bin/activate
-  #     - run:
-  #         name: Install dependencies
-  #         command: |
-  #           source env/bin/activate
-  #           pip install --upgrade pip
-  #           pip install .
-  #     - run_chatgpt_api_test:
-  #         inference_engine: tinygrad
-  #         model_id: llama-3-8b
+  chatgpt_api_integration_test_tinygrad:
+    macos:
+      xcode: "16.0.0"
+    resource_class: m2pro.large
+    steps:
+      - checkout
+      - run:
+          name: Set up Python
+          command: |
+            brew install python@3.12
+            python3.12 -m venv env
+            source env/bin/activate
+      - run:
+          name: Install dependencies
+          command: |
+            source env/bin/activate
+            pip install --upgrade pip
+            pip install .
+      - run_chatgpt_api_test:
+          inference_engine: tinygrad
+          model_id: llama-3.2-1b
+          prompt: "Keep responses concise. Who was the king of pop?"
+          expected_output: "Michael Jackson"

workflows:
  version: 2
@@ -215,5 +254,6 @@ workflows:
      - unit_test
      - discovery_integration_test
      - chatgpt_api_integration_test_mlx
+      - chatgpt_api_integration_test_tinygrad
+      - chatgpt_api_integration_test_dummy
      - test_macos_m1
-      # - chatgpt_api_integration_test_tinygrad
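For reference, the tinygrad path of the parameterized test can be approximated locally; a sketch, assuming a single node and the same flags the CI job passes (the port numbers and cache path are illustrative):

```sh
# Mirror the CI environment for the tinygrad engine (single node, illustrative values)
pip install llvmlite
export TOKENIZERS_PARALLELISM=true SUPPORT_BF16=0 CLANG=1
HF_HOME="$(pwd)/.hf_cache_node1" DEBUG=7 exo --inference-engine tinygrad \
  --node-id "node1" --listen-port 5678 --broadcast-port 5679 \
  --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 --disable-tui
```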

+ 3 - 2
.gitignore

@@ -1,9 +1,10 @@
__pycache__/
-.venv
+.venv*
test_weights.npz
.exo_used_ports
.exo_node_id
.idea
+.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -15,7 +16,6 @@ __pycache__/

# Distribution / packaging
/.Python
-/build/
/develop-eggs/
/dist/
/downloads/
@@ -170,3 +170,4 @@ cython_debug/
#.idea/

**/*.xcodeproj/*
+.aider*

+ 0 - 472
.pylintrc

@@ -1,472 +0,0 @@
-[MASTER]
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code
-extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=.*node_service_pb2.*
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint.
-jobs=4
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Specify a configuration file.
-#rcfile=
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
-# E1101 for function binding
-# W0221 for Function class
-# W0105 for comment strings
-# E0401 for missing imports
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member,use-a-generator, no-else-return
-
-
-[REPORTS]
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio).You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Activate the evaluation score.
-score=yes
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=optparse.Values,sys.exit
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-
-[SIMILARITIES]
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging  or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='  '
-
-# Maximum number of characters on a single line.
-max-line-length=150
-
-# Maximum number of lines in a module
-max-module-lines=1000
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[BASIC]
-
-# Naming style matching correct argument names
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style
-#argument-rgx=
-
-# Naming style matching correct attribute names
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=foo,
-          bar,
-          baz,
-          toto,
-          tutu,
-          tata
-
-# Naming style matching correct class attribute names
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style
-#class-attribute-rgx=
-
-# Naming style matching correct class names
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-style
-#class-rgx=
-
-# Naming style matching correct constant names
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=i,
-           j,
-           k,
-           ex,
-           Run,
-           _
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style
-#inlinevar-rgx=
-
-# Naming style matching correct method names
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style
-#method-rgx=
-
-# Naming style matching correct module names
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-property-classes=abc.abstractproperty
-
-# Naming style matching correct variable names
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style
-#variable-rgx=
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in a if statement
-max-bool-expr=5
-
-# Maximum number of branch for function / method body
-max-branches=12
-
-# Maximum number of locals for function / method body
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body
-max-returns=6
-
-# Maximum number of statements in function / method body
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[IMPORTS]
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,
-                   TERMIOS,
-                   Bastion,
-                   rexec
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-[STRING]
-
-# This flag controls whether the implicit-str-concat should generate a warning
-# on implicit string concatenation in sequences defined over several lines.
-check-str-concat-over-line-jumps=yes
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=builtins.Exception

+ 26 - 11
README.md

@@ -58,12 +58,7 @@ Unlike other distributed inference frameworks, exo does not use a master-worker

Exo supports different [partitioning strategies](exo/topology/partitioning_strategy.py) to split up a model across devices. The default partitioning strategy is [ring memory weighted partitioning](exo/topology/ring_memory_weighted_partitioning_strategy.py). This runs an inference in a ring where each device runs a number of model layers proportional to the memory of the device.

-<p>
-    <picture>
-        <img alt="ring topology" src="docs/ring-topology.png" width="30%" height="30%">
-    </picture>
-</p>
-
+!["A screenshot of exo running 5 nodes](docs/exo-screenshot.png)

## Installation

@@ -126,14 +121,14 @@ exo

That's it! No configuration required - exo will automatically discover the other device(s).

-exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
+exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:52415
-For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Examples with curl:
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:52415/v1/chat/completions. Examples with curl:

#### Llama 3.2 3B:

```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "model": "llama-3.2-3b",
@@ -145,7 +140,7 @@ curl http://localhost:8000/v1/chat/completions \
#### Llama 3.1 405B:

```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "model": "llama-3.1-405b",
@@ -157,7 +152,7 @@ curl http://localhost:8000/v1/chat/completions \
#### Llava 1.5 7B (Vision Language Model):

```sh
-curl http://localhost:8000/v1/chat/completions \
+curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "model": "llava-1.5-7b-hf",
@@ -213,6 +208,12 @@ With a custom prompt:
exo run llama-3.2-3b --prompt "What is the meaning of exo?"
```

+### Model Storage
+
+Models by default are stored in `~/.cache/huggingface/hub`.
+
+You can set a different model storage location by setting the `HF_HOME` env var.
+
## Debugging

Enable debug logs with the DEBUG environment variable (0-9).
@@ -227,6 +228,20 @@ For the **tinygrad** inference engine specifically, there is a separate DEBUG fl
TINYGRAD_DEBUG=2 exo
```

+## Formatting
+
+We use [yapf](https://github.com/google/yapf) to format the code. To format the code, first install the formatting requirements:
+
+```sh
+pip3 install -e '.[formatting]'
+```
+
+Then run the formatting script:
+
+```sh
+python3 format.py ./exo
+```
+
## Known Issues

- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
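The new Model Storage note can be exercised with an explicit override; a minimal sketch, assuming a custom cache path:

```sh
# Keep downloaded models outside ~/.cache/huggingface/hub (path is illustrative)
HF_HOME=/path/to/models exo
```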

+ 18 - 2
configure_mlx.sh

@@ -1,2 +1,18 @@
-sudo sysctl iogpu.wired_lwm_mb=400000
-sudo sysctl iogpu.wired_limit_mb=180000
+#!/bin/bash
+
+# Get the total memory in MB
+TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
+
+# Set WIRED_LIMIT_MB to 80%
+WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
+# Set  WIRED_LWM_MB to 70%
+WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
+
+# Display the calculated values
+echo "Total memory: $TOTAL_MEM_MB MB"
+echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
+echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
+
+# Apply the values with sysctl
+sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
+sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
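A possible way to run and verify the rewritten script (a sketch, assuming an Apple Silicon Mac; the reported values depend on installed RAM):

```sh
./configure_mlx.sh                                # prints total memory and the 80%/70% limits it applies
sysctl iogpu.wired_limit_mb iogpu.wired_lwm_mb    # confirm the values that were set
```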

BIN
docs/exo-rounded.png


BIN
docs/exo-screenshot.png


+ 1 - 1
examples/astra/astra/ContentView.swift

@@ -148,7 +148,7 @@ struct ContentView: View {
    @State private var voiceActivityThreshold: Float = 0.40
    @State private var silenceTimeThreshold = 1.0
    @State private var debugText = ""
-    @State private var apiEndpoint = "http://192.168.212.74:8000/v1/chat/completions"
+    @State private var apiEndpoint = "http://192.168.212.74:52415/v1/chat/completions"
    @State private var audioBuffer: [Float] = []
    @State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
    @State private var isInitialTranscription = true

+ 1 - 1
examples/chatgpt_api.sh

@@ -3,7 +3,7 @@
# This works the same in a single-node set up and in a multi-node setup.
# You need to start exo before running this by running `python3 main.py`.

-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
MODEL="llama-3.1-8b"
PROMPT="What is the meaning of exo?"
TEMPERATURE=0.7
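Because the script only falls back to ifconfig when API_ENDPOINT is unset, it can be pointed at any node; a sketch with an illustrative address and the new default port:

```sh
API_ENDPOINT=192.168.1.10:52415 ./examples/chatgpt_api.sh
```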

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
+from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION

+ 76 - 56
exo/api/chatgpt_api.py

@@ -8,15 +8,16 @@ from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
import traceback
+import os
+import signal
+import sys
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict
-from exo.inference.shard import Shard
+from exo.helpers import PrefixDict, shutdown
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
-from exo.models import model_base_shards
-from typing import Callable
-
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from typing import Callable, Optional

class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -27,6 +28,7 @@ class Message:
    return {"role": self.role, "content": self.content}


+
class ChatCompletionRequest:
  def __init__(self, model: str, messages: List[Message], temperature: float):
    self.model = model
@@ -117,19 +119,11 @@ def remap_messages(messages: List[Message]) -> List[Message]:
def build_prompt(tokenizer, _messages: List[Message]):
  messages = remap_messages(_messages)
  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  image_str = None
  for message in messages:
    if not isinstance(message.content, list):
      continue

-    for content in message.content:
-      # note: we only support one image at a time right now. Multiple is possible. See: https://github.com/huggingface/transformers/blob/e68ec18ce224af879f22d904c7505a765fb77de3/docs/source/en/model_doc/llava.md?plain=1#L41
-      # follows the convention in https://platform.openai.com/docs/guides/vision
-      if isinstance(content, dict) and content.get("type", None) == "image":
-        image_str = content.get("image", None)
-        break
-
-  return prompt, image_str
+  return prompt


def parse_message(data: dict):
@@ -138,9 +132,9 @@ def parse_message(data: dict):
  return Message(data["role"], data["content"])


-def parse_chat_request(data: dict):
+def parse_chat_request(data: dict, default_model: str):
  return ChatCompletionRequest(
-    data.get("model", "llama-3.1-8b"),
+    data.get("model", default_model),
    [parse_message(msg) for msg in data["messages"]],
    data.get("temperature", 0.0),
  )
@@ -152,9 +146,8 @@ class PromptSession:
    self.timestamp = timestamp
    self.prompt = prompt

-
class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout = response_timeout
@@ -163,6 +156,8 @@ class ChatGPTAPI:
    self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
+    self.default_model = default_model or "llama-3.2-1b"
+
    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
      allow_credentials=True,
@@ -176,15 +171,34 @@ class ChatGPTAPI:
    cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
    cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
    cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
-    # Endpoint for download progress tracking
    cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
+    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
+    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
+    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
-    self.static_dir = Path(__file__).parent.parent/"tinychat"
-    self.app.router.add_get("/", self.handle_root)
-    self.app.router.add_static("/", self.static_dir, name="static")
+    if "__compiled__" not in globals():
+      self.static_dir = Path(__file__).parent.parent/"tinychat"
+      self.app.router.add_get("/", self.handle_root)
+      self.app.router.add_static("/", self.static_dir, name="static")
-    # Add middleware to log every request
+    self.app.middlewares.append(self.timeout_middleware)
    self.app.middlewares.append(self.log_request)
+  
+  async def handle_quit(self, request):
+    if DEBUG>=1: print("Received quit signal")
+    response = web.json_response({"detail": "Quit signal received"}, status=200)
+    await response.prepare(request)
+    await response.write_eof()
+    await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)
+
+  async def timeout_middleware(self, app, handler):
+    async def middleware(request):
+      try:
+        return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
+      except asyncio.TimeoutError:
+        return web.json_response({"detail": "Request timed out"}, status=408)
+
+    return middleware

  async def log_request(self, app, handler):
    async def middleware(request):
@@ -196,52 +210,58 @@ class ChatGPTAPI:
  async def handle_root(self, request):
    return web.FileResponse(self.static_dir/"index.html")

+  async def handle_healthcheck(self, request):
+    return web.json_response({"status": "ok"})
+
+  async def handle_model_support(self, request):
+    return web.json_response({
+      "model pool": {
+        model_name: pretty_name.get(model_name, model_name) 
+        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
+      }
+    })
+  
  async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
-    shard = model_base_shards.get(data.get("model", "llama-3.1-8b"), {}).get(self.inference_engine_classname)
+    shard = build_base_shard(self.default_model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

  async def handle_get_download_progress(self, request):
    progress_data = {}
    for node_id, progress_event in self.node.node_download_progress.items():
-        if isinstance(progress_event, RepoProgressEvent):
-            # Convert to dict if not already
-            progress_data[node_id] = progress_event.to_dict()
-        elif isinstance(progress_event, dict):
-            progress_data[node_id] = progress_event
-        else:
-            # Handle unexpected types
-            progress_data[node_id] = str(progress_event)
+      if isinstance(progress_event, RepoProgressEvent):
+        progress_data[node_id] = progress_event.to_dict()
+      else:
+        print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
    return web.json_response(progress_data)

-
  async def handle_post_chat_completions(self, request):
    data = await request.json()
    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
    stream = data.get("stream", False)
-    chat_request = parse_chat_request(data)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = "llama-3.1-8b"
-    if not chat_request.model or chat_request.model not in model_base_shards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_base_shards.keys())}. Defaulting to llama-3.1-8b")
-      chat_request.model = "llama-3.1-8b"
-    shard = model_base_shards[chat_request.model].get(self.inference_engine_classname, None)
+    chat_request = parse_chat_request(data, self.default_model)
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
+    if not chat_request.model or chat_request.model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      chat_request.model = self.default_model
+    shard = build_base_shard(chat_request.model, self.inference_engine_classname)
    if not shard:
-      supported_models = [model for model, engines in model_base_shards.items() if self.inference_engine_classname in engines]
+      supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
      return web.json_response(
        {"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
        status=400,
      )

-    tokenizer = await resolve_tokenizer(shard.model_id)
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

-    prompt, image_str = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages)
    request_id = str(uuid.uuid4())
    if self.on_chat_completion_request:
      try:
@@ -264,14 +284,11 @@ class ChatGPTAPI:
    callback_id = f"chatgpt-api-wait-response-{request_id}"
    callback = self.node.on_token.register(callback_id)

-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
-    try:
-      await self.node.process_prompt(shard, prompt, image_str, request_id=request_id)
-    except Exception as e:
-      if DEBUG >= 2: traceback.print_exc()
-      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")

    try:
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
+
      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")

      if stream:
@@ -285,9 +302,9 @@ class ChatGPTAPI:
        )
        await response.prepare(request)

-        async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
-          prev_last_tokens_len = self.prev_token_lens.get(request_id, 0)
-          self.prev_token_lens[request_id] = max(prev_last_tokens_len, len(tokens))
+        async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
+          prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
+          self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
          new_tokens = tokens[prev_last_tokens_len:]
          finish_reason = None
          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
@@ -317,7 +334,7 @@ class ChatGPTAPI:
            if DEBUG >= 2: traceback.print_exc()

        def on_result(_request_id: str, tokens: List[int], is_finished: bool):
-          self.stream_tasks[request_id] = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))

          return _request_id == request_id and is_finished

@@ -346,11 +363,14 @@ class ChatGPTAPI:
        return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
    except asyncio.TimeoutError:
      return web.json_response({"detail": "Response generation timed out"}, status=408)
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
    finally:
      deregistered_callback = self.node.on_token.deregister(callback_id)
      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

-  async def run(self, host: str = "0.0.0.0", port: int = 8000):
+  async def run(self, host: str = "0.0.0.0", port: int = 52415):
    runner = web.AppRunner(self.app)
    await runner.setup()
    site = web.TCPSite(runner, host, port)
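The endpoints added in this file can be exercised directly once a node is running; a sketch assuming a local node on the new default port 52415:

```sh
curl http://localhost:52415/healthcheck       # -> {"status": "ok"}
curl http://localhost:52415/modelpool         # models supported by the discovered inference engines
curl -X POST http://localhost:52415/quit      # asks the node to shut down (SIGINT-style shutdown)
```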

+ 39 - 8
exo/download/hf/hf_helpers.py

@@ -1,7 +1,11 @@
+import aiofiles.os as aios
+from typing import Union
import asyncio
import aiohttp
import json
import os
+import sys
+import shutil
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from datetime import datetime, timedelta
@@ -9,15 +13,13 @@ from fnmatch import fnmatch
from pathlib import Path
from typing import Generator, Iterable, TypeVar, TypedDict
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG
+from exo.helpers import DEBUG, is_frozen
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
from exo.inference.shard import Shard
import aiofiles
-from aiofiles import os as aios
 
 
 T = TypeVar("T")
 T = TypeVar("T")
 
 
-
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
 async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_dir = get_repo_root(repo_id)/"refs"
   refs_file = refs_dir/revision
   refs_file = refs_dir/revision
@@ -70,8 +72,10 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
     return pattern + "*"
   return pattern
   return pattern
 
 
+
 def get_hf_endpoint() -> str:
 def get_hf_endpoint() -> str:
-    return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+
 
 
 def get_hf_home() -> Path:
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   """Get the Hugging Face home directory."""
@@ -97,10 +101,27 @@ async def get_auth_headers():
 
 
 def get_repo_root(repo_id: str) -> Path:
 def get_repo_root(repo_id: str) -> Path:
   """Get the root directory for a given repo ID in the Hugging Face cache."""
   """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = repo_id.replace("/", "--")
+  sanitized_repo_id = str(repo_id).replace("/", "--")
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
   return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
 
 
-
+async def move_models_to_hf(seed_dir: Union[str, Path]):
+  """Move model in resources folder of app to .cache/huggingface/hub"""
+  source_dir = Path(seed_dir)
+  dest_dir = get_hf_home()/"hub"
+  await aios.makedirs(dest_dir, exist_ok=True)  
+  for path in source_dir.iterdir():
+    if path.is_dir() and path.name.startswith("models--"):
+      dest_path = dest_dir / path.name
+      if await aios.path.exists(dest_path):
+        print('Skipping moving model to .cache directory')
+      else:
+        try:
+          await aios.rename(str(path), str(dest_path))
+        except Exception as e:
+          print(f'Error moving model to .cache: {e}')
+    
+    
+    
 async def fetch_file_list(session, repo_id, revision, path=""):
   api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
   url = f"{api_url}/{path}" if path else api_url
@@ -394,7 +415,7 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:


 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
+  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
@@ -407,6 +428,16 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
    elif shard.is_last_layer():
      shard_specific_patterns.add(sorted_file_names[-1])
  else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
  return list(default_patterns | shard_specific_patterns)
+
+async def has_hf_home_read_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.R_OK)
+  except OSError: return False
+
+async def has_hf_home_write_access() -> bool:
+  hf_home = get_hf_home()
+  try: return await aios.access(hf_home, os.W_OK)
+  except OSError: return False

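The new helpers above make the cache location and its permissions explicit. A minimal sketch of how a caller might combine them with get_allow_patterns (the shard below is a stand-in with no real weight map):

import asyncio
from exo.download.hf.hf_helpers import get_hf_home, get_allow_patterns, has_hf_home_read_access, has_hf_home_write_access
from exo.inference.shard import Shard

async def check_cache():
  # Fail fast when the Hugging Face cache directory is not usable.
  print("HF home:", get_hf_home())
  print("readable:", await has_hf_home_read_access(), "writable:", await has_hf_home_write_access())
  # With an empty weight map every *.safetensors file is allowed (note the corrected set literal above).
  print(get_allow_patterns({}, Shard("dummy-model", 0, 0, 1)))

asyncio.run(check_cache())
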
+ 8 - 6
exo/download/hf/hf_shard_download.py

@@ -7,6 +7,7 @@ from exo.download.shard_download import ShardDownloader
 from exo.download.download_progress import RepoProgressEvent
 from exo.download.hf.hf_helpers import download_repo_files, RepoProgressEvent, get_weight_map, get_allow_patterns, get_repo_root
 from exo.helpers import AsyncCallbackSystem, DEBUG
+from exo.models import model_cards, get_repo


 class HFShardDownloader(ShardDownloader):
@@ -17,11 +18,12 @@ class HFShardDownloader(ShardDownloader):
     self.completed_downloads: Dict[Shard, Path] = {}
     self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    repo_name = get_repo(shard.model_id, inference_engine_name)
    if shard in self.completed_downloads:
      return self.completed_downloads[shard]
    if self.quick_check:
-      repo_root = get_repo_root(shard.model_id)
+      repo_root = get_repo_root(repo_name)
      snapshots_dir = repo_root/"snapshots"
      if snapshots_dir.exists():
        visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
@@ -51,7 +53,7 @@ class HFShardDownloader(ShardDownloader):
    self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}

    # Start new download
-    download_task = asyncio.create_task(self._download_shard(shard))
+    download_task = asyncio.create_task(self._download_shard(shard, repo_name))
    self.active_downloads[shard] = download_task
    try:
      path = await download_task
@@ -63,14 +65,14 @@ class HFShardDownloader(ShardDownloader):
      if shard in self.active_downloads:
        self.active_downloads.pop(shard)

-  async def _download_shard(self, shard: Shard) -> Path:
+  async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
    async def wrapped_progress_callback(event: RepoProgressEvent):
      self._on_progress.trigger_all(shard, event)

-    weight_map = await get_weight_map(shard.model_id)
+    weight_map = await get_weight_map(repo_name)
    allow_patterns = get_allow_patterns(weight_map, shard)

-    return await download_repo_files(repo_id=shard.model_id, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
+    return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:

+ 11 - 1
exo/download/shard_download.py

@@ -8,7 +8,7 @@ from exo.helpers import AsyncCallbackSystem

 class ShardDownloader(ABC):
   @abstractmethod
-  async def ensure_shard(self, shard: Shard) -> Path:
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
     """
         Ensures that the shard is downloaded.
         Does not allow multiple overlapping downloads at once.
@@ -17,6 +17,7 @@ class ShardDownloader(ABC):

         Args:
             shard (Shard): The shard to download.
+            inference_engine_name (str): The inference engine used on the node hosting the shard
         """
     pass

@@ -24,3 +25,12 @@ class ShardDownloader(ABC):
   @abstractmethod
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
+
+
+class NoopShardDownloader(ShardDownloader):
+  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
+    return Path("/tmp/noop_shard")
+
+  @property
+  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
+    return AsyncCallbackSystem()

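Since ensure_shard now also receives the inference engine's class name, a downloader can resolve a model id to an engine-specific repository or layout. A sketch of a custom downloader honouring the new signature (NoopShardDownloader above is the minimal reference; the local-directory layout here is purely illustrative):

from pathlib import Path
from typing import Tuple
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import AsyncCallbackSystem
from exo.inference.shard import Shard

class LocalDirShardDownloader(ShardDownloader):
  """Illustrative downloader that serves shards from a pre-populated local directory."""
  def __init__(self, root: Path):
    self.root = root
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    # inference_engine_name (e.g. "MLXDynamicShardInferenceEngine") could select an
    # engine-specific checkpoint layout; here it only influences the returned path.
    return self.root / inference_engine_name / shard.model_id

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress
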
+ 20 - 1
exo/helpers.py

@@ -1,4 +1,5 @@
 import os
+import sys
 import asyncio
 from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
 import socket
@@ -170,7 +171,7 @@ def is_valid_uuid(val):


 def get_or_create_node_id():
-  NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
+  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:
@@ -234,3 +235,21 @@ def get_all_ip_addresses():
   except:
     if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
     return ["localhost"]
+
+
+async def shutdown(signal, loop, server):
+  """Gracefully shutdown the server and close the asyncio loop."""
+  print(f"Received exit signal {signal.name}...")
+  print("Thank you for using exo.")
+  print_yellow_exo()
+  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+  [task.cancel() for task in server_tasks]
+  print(f"Cancelling {len(server_tasks)} outstanding tasks")
+  await asyncio.gather(*server_tasks, return_exceptions=True)
+  await server.stop()
+
+
+def is_frozen():
+  return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
+    or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
+    or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)

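The new shutdown helper cancels outstanding tasks and then stops the server. A sketch of wiring it to POSIX signals; the server object is assumed only to expose an async stop(), which is the one part of its interface the helper relies on:

import asyncio
import signal
from exo.helpers import shutdown

async def serve_forever(server):
  loop = asyncio.get_running_loop()
  for sig in (signal.SIGINT, signal.SIGTERM):
    # Each signal triggers a graceful shutdown: cancel tasks, await them, then stop the server.
    loop.add_signal_handler(sig, lambda sig=sig: asyncio.create_task(shutdown(sig, loop, server)))
  await asyncio.Event().wait()  # park until shutdown() cancels this task
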
+ 11 - 12
exo/inference/debug_inference_engine.py

@@ -13,32 +13,31 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))

   prompt = "In a single word only, what is the last name of the president of the United States? "
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    input_data=token_full,
   )

-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp1,
-    inference_state=inference_state_1,
   )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  token2 = await inference_engine_2.sample(resp2)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    input_data=token2,
   )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
     input_data=resp3,
-    inference_state=inference_state_3,
   )

   print(f"{resp2=}")

+ 43 - 0
exo/inference/dummy_inference_engine.py

@@ -0,0 +1,43 @@
+from typing import Optional, Tuple, TYPE_CHECKING
+import numpy as np
+import random
+import string
+import asyncio
+import json
+from exo.inference.inference_engine import InferenceEngine
+from exo.inference.shard import Shard
+def random_string(length: int):
+  return ''.join([random.choice(string.ascii_lowercase) for i in range(length)])
+  
+
+class DummyInferenceEngine(InferenceEngine):
+  def __init__(self):
+    self.shard = None
+    self.vocab_size = 1000
+    self.hidden_size = 256
+    self.eos_token_id = 0
+    self.latency_mean = 0.1
+    self.latency_stddev = 0.02
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size, size=(1, len(prompt.split())))
+  
+  async def sample(self, x: np.ndarray) -> np.ndarray:
+    return np.random.randint(1, self.vocab_size)
+
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+    return ' '.join([random_string(np.random.randint(1, 34)) for token in tokens])
+
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    await self.ensure_shard(shard)
+    sequence_length = input_data.shape[0 if self.shard.is_first_layer() else 1]
+    output = np.random.random(size=(1, sequence_length, self.vocab_size if self.shard.is_last_layer() else self.hidden_size))
+    return output
+
+  async def ensure_shard(self, shard: Shard):
+    if self.shard == shard:
+      return
+    # Simulate shard loading without making any API calls
+    await asyncio.sleep(0.1)  # Simulate a short delay
+    self.shard = shard
+    print(f"DummyInferenceEngine: Simulated loading of shard {shard.model_id}")

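The dummy engine mimics the real engines' shape contract: shards containing the last layer emit a trailing vocab_size dimension, intermediate shards emit hidden_size. A short sketch (model id and input are arbitrary):

import asyncio
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard

async def demo():
  engine = DummyInferenceEngine()
  shard = Shard(model_id="dummy", start_layer=0, end_layer=0, n_layers=1)  # single-shard "model"
  out = await engine.infer_tensor("req-1", shard, np.array([[1, 2, 3]]))
  # Last-layer shards end in vocab_size (1000 here); a middle shard would end in hidden_size (256).
  print(out.shape, out.shape[-1] == engine.vocab_size)

asyncio.run(demo())
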
+ 28 - 4
exo/inference/inference_engine.py

@@ -1,5 +1,6 @@
 import numpy as np
 import os
+from exo.helpers import DEBUG  # Make sure to import DEBUG

 from typing import Tuple, Optional
 from abc import ABC, abstractmethod
@@ -8,15 +9,36 @@ from .shard import Shard

 class InferenceEngine(ABC):
   @abstractmethod
-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+    pass
+  
+  @abstractmethod
+  async def sample(self, x: np.ndarray) -> np.ndarray:
     pass

   @abstractmethod
-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
     pass

+  @abstractmethod
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    pass
+  
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
+    tokens = await self.encode(shard, prompt)
+    x = tokens.reshape(1, -1)
+    output_data = await self.infer_tensor(request_id, shard, x)
+    return output_data 
+
+inference_engine_classes = {
+  "mlx": "MLXDynamicShardInferenceEngine",
+  "tinygrad": "TinygradDynamicShardInferenceEngine",
+  "dummy": "DummyInferenceEngine",
+}

 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+  if DEBUG >= 2:
+    print(f"get_inference_engine called with: {inference_engine_name}")
   if inference_engine_name == "mlx":
     from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine

@@ -27,5 +49,7 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDow
     tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))

     return TinygradDynamicShardInferenceEngine(shard_downloader)
-  else:
-    raise ValueError(f"Inference engine {inference_engine_name} not supported")
+  elif inference_engine_name == "dummy":
+    from exo.inference.dummy_inference_engine import DummyInferenceEngine
+    return DummyInferenceEngine()
+  raise ValueError(f"Unsupported inference engine: {inference_engine_name}")

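The refactor splits the old stateful infer_prompt/infer_tensor API into explicit encode, infer_tensor, sample and decode steps, with infer_prompt provided as a default composition. A rough sketch of the resulting generation loop, using the dummy engine so nothing is downloaded (request id, shard and token handling are deliberately simplified):

import asyncio
import numpy as np
from exo.download.shard_download import NoopShardDownloader
from exo.inference.inference_engine import get_inference_engine
from exo.inference.shard import Shard

async def generate(prompt: str, max_tokens: int = 5) -> str:
  engine = get_inference_engine("dummy", NoopShardDownloader())
  shard = Shard(model_id="dummy", start_layer=0, end_layer=0, n_layers=1)
  out = await engine.infer_prompt("req-1", shard, prompt)   # encode + first forward pass
  tokens = []
  for _ in range(max_tokens):
    token = await engine.sample(out)                        # pick the next token from the output
    tokens.append(int(np.reshape(token, -1)[0]))
    out = await engine.infer_tensor("req-1", shard, np.reshape(token, (1, -1)))
  return await engine.decode(shard, tokens)

print(asyncio.run(generate("hello world")))
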
+ 1 - 1
exo/inference/mlx/models/base.py

@@ -1,7 +1,7 @@
 from typing import Optional
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache


 class IdentityBlock(nn.Module):

+ 1 - 1
exo/inference/mlx/models/deepseek_v2.py

@@ -4,7 +4,7 @@ from typing import Optional
 import mlx.core as mx
 import mlx.nn as nn

-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache
 from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
 from .base import IdentityBlock
 from exo.inference.shard import Shard

+ 118 - 0
exo/inference/mlx/models/gemma2.py

@@ -0,0 +1,118 @@
+from dataclasses import dataclass, field
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import create_attention_mask
+from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
+
+from ...shard import Shard
+from .base import IdentityBlock
+
+
+@dataclass
+class ModelArgs(ModelArgs):
+  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
+
+  def __post_init__(self):
+    if isinstance(self.shard, Shard):
+      return
+    if not isinstance(self.shard, dict):
+      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
+
+    self.shard = Shard(**self.shard)
+
+
+class GemmaModel(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.vocab_size = args.vocab_size
+    self.num_hidden_layers = args.num_hidden_layers
+    assert self.vocab_size > 0
+    if args.shard.is_first_layer() or args.shard.is_last_layer():
+      self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+    self.layers = []
+    for i in range(self.num_hidden_layers):
+      if args.shard.start_layer <= i <= args.shard.end_layer:
+        self.layers.append(TransformerBlock(args=args))
+      else:
+        self.layers.append(IdentityBlock())
+    if args.shard.is_last_layer():
+      self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    if self.args.shard.is_first_layer():
+      h = self.embed_tokens(inputs)
+      h = h * (self.args.hidden_size**0.5)
+    else:
+      h = inputs
+
+    mask = None
+    if h.ndim > 1 and h.shape[1] > 1:
+      mask = create_attention_mask(h, cache)
+
+    if cache is None:
+      cache = [None]*len(self.layers)
+
+    for layer, c in zip(self.layers, cache):
+      h = layer(h, mask, cache=c)
+
+    if self.args.shard.is_last_layer():
+      h = self.norm(h)
+    return h
+
+
+class Model(nn.Module):
+  def __init__(self, args: ModelArgs):
+    super().__init__()
+    self.args = args
+    self.model_type = args.model_type
+    self.model = GemmaModel(args)
+    if args.shard.is_last_layer():
+      self.final_logit_softcapping = args.final_logit_softcapping
+
+  def __call__(
+    self,
+    inputs: mx.array,
+    cache=None,
+  ):
+    out = self.model(inputs, cache)
+    if self.args.shard.is_last_layer():
+      out = self.model.embed_tokens.as_linear(out)
+      out = mx.tanh(out / self.final_logit_softcapping)
+      out = out * self.final_logit_softcapping
+    return out
+
+  def sanitize(self, weights):
+    shard_state_dict = {}
+
+    for key, value in weights.items():
+      if "self_attn.rotary_emb.inv_freq" in key:
+        continue
+      if key.startswith('model.layers.'):
+        layer_num = int(key.split('.')[2])
+        if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
+          shard_state_dict[key] = value
+      elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
+        shard_state_dict[key] = value
+      elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
+        shard_state_dict[key] = value
+
+    return shard_state_dict
+
+  @property
+  def layers(self):
+    return self.model.layers
+
+  @property
+  def head_dim(self):
+    return self.args.head_dim
+
+  @property
+  def n_kv_heads(self):
+    return self.args.num_key_value_heads

+ 4 - 4
exo/inference/mlx/models/llama.py

@@ -32,15 +32,15 @@ class LlamaModel(nn.Module):
     self.vocab_size = args.vocab_size
     self.num_hidden_layers = args.num_hidden_layers
     assert self.vocab_size > 0
-    if self.args.shard.is_first_layer():
+    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
       self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
     self.layers = []
     for i in range(self.num_hidden_layers):
-      if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
+      if args.shard.start_layer <= i <= args.shard.end_layer:
         self.layers.append(TransformerBlock(args=args))
       else:
         self.layers.append(IdentityBlock())
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

   def __call__(
@@ -74,7 +74,7 @@ class Model(nn.Module):
     self.args = args
     self.model_type = args.model_type
     self.model = LlamaModel(args)
-    if self.args.shard.is_last_layer():
+    if args.shard.is_last_layer():
       if not args.tie_word_embeddings:
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)


+ 2 - 1
exo/inference/mlx/models/qwen2.py

@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):

     self.shard = Shard(**self.shard)

+
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -57,7 +58,7 @@ class Qwen2Model(nn.Module):
       mask = create_attention_mask(h, cache)

     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)

     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

+ 47 - 20
exo/inference/mlx/sharded_inference_engine.py

@@ -1,13 +1,35 @@
 import numpy as np
 import mlx.core as mx
+import mlx.nn as nn
 from ..inference_engine import InferenceEngine
-from .sharded_model import StatefulShardedModel
+from .stateful_model import StatefulModel
 from .sharded_utils import load_shard, get_image_from_str
 from ..shard import Shard
-from typing import Optional
+from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+def sample_logits(
+  logits: mx.array,
+  temp: float = 0.0,
+  top_p: float = 1.0,
+  logit_bias: Optional[Dict[int, float]] = None
+) -> Tuple[mx.array, float]:
+  if logit_bias:
+    indices = mx.array(list(logit_bias.keys()))
+    values = mx.array(list(logit_bias.values()))
+    logits[:, indices] += values
+
+  if temp == 0:
+    token = mx.argmax(logits, axis=-1)
+  else:
+    if top_p > 0 and top_p < 1.0:
+      token = top_p_sampling(logits, top_p, temp)
+    else:
+      token = mx.random.categorical(logits*(1/temp))
+
+  return token

 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
@@ -15,34 +37,39 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)

-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    y = mx.array(x)
+    logits = y[:, -1, :]
+    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
+    return out
+
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
    await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
-    if image_str:
-      image = await get_image_from_str(image_str)
-      inputs = await loop.run_in_executor(self.executor, self.tokenizer, prompt, image, return_tensors="np")
-      pixel_values = mx.array(inputs["pixel_values"])
-      input_ids = mx.array(inputs["input_ids"])
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids, pixel_values))
-    else:
-      input_ids = mx.array(await loop.run_in_executor(self.executor, self.tokenizer.encode, prompt))
-      output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, input_ids))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return np.array(tokens)

-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens
+    
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
    await self.ensure_shard(shard)
-    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.stateful_sharded_model.step, request_id, mx.array(input_data)))
-    return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
+    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
+    return output_data

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
      return

-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)

    if self.shard != shard:
      loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
      model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
-      self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
      self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 

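sample_logits is now a module-level function rather than logic buried in the old StatefulShardedModel.step: with temp=0 it reduces to argmax, otherwise top-p or categorical sampling, and logit_bias is applied before sampling. A tiny sketch with made-up logits:

import mlx.core as mx
from exo.inference.mlx.sharded_inference_engine import sample_logits

logits = mx.array([[0.1, 2.5, -1.0, 0.3]])
print(sample_logits(logits, temp=0.0).item())                        # greedy: index of the max logit
print(sample_logits(logits, temp=0.0, logit_bias={3: 10.0}).item())  # bias lifts index 3 above the rest
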
+ 0 - 86
exo/inference/mlx/sharded_model.py

@@ -1,86 +0,0 @@
-from typing import Dict, Generator, Optional, Tuple
-from collections import OrderedDict
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import KVCache, RotatingKVCache
-from mlx_lm.sample_utils import top_p_sampling
-
-from ..shard import Shard
-
-# TODO: support a speculative model so we can parallelise compute across devices
-class StatefulShardedModel:
-  def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
-    self.shard = shard
-    self.model = model
-    self.max_kv_size = max_kv_size
-    self.max_caches = max_caches
-    self.caches = OrderedDict()
-
-  def step(
-    self,
-    request_id: str,
-    x,
-    pixel_values=None,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-      if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-        logits[:, indices] += values
-
-      if temp == 0:
-        token = mx.argmax(logits, axis=-1)
-      else:
-        if top_p > 0 and top_p < 1.0:
-          token = top_p_sampling(logits, top_p, temp)
-        else:
-          token = mx.random.categorical(logits*(1/temp))
-
-      return token
-
-    y = x
-
-    if request_id not in self.caches:
-      self.init_cache(request_id)
-    else:
-      self.caches.move_to_end(request_id)
-
-    cache = self.caches[request_id]
-
-    if pixel_values is None:
-      output = self.model(y[None] if self.shard.is_first_layer() else y, cache=cache)
-    else:
-      output = self.model(y, pixel_values=pixel_values, cache=cache)
-
-    if self.shard.is_last_layer():
-      logits = output[:, -1, :]
-      y = sample(logits)
-      return y
-    else:
-      return output
-
-  def __call__(
-    self,
-    request_id: str,
-    x,
-    temp: float = 0.0,
-    top_p: float = 1.0,
-    logit_bias: Optional[Dict[int, float]] = None,
-  ) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    return self.step(request_id, x, temp=temp, top_p=top_p, logit_bias=logit_bias)
-
-  def init_cache(self, request_id: str):
-    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
-    if self.max_kv_size is not None:
-      cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
-    else:
-      cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
-
-    if len(self.caches) >= self.max_caches:
-      self.caches.popitem(last=False)
-
-    self.caches[request_id] = cache

+ 14 - 8
exo/inference/mlx/sharded_utils.py

@@ -12,15 +12,16 @@ from typing import Optional, Tuple, Union, List, Callable
 from PIL import Image
 from io import BytesIO
 import base64
+import traceback

 import mlx.core as mx
 import mlx.nn as nn
 from transformers import AutoProcessor

 from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
-from mlx_lm.tuner.utils import apply_lora_layers

 from exo import DEBUG
+from exo.inference.tokenizers import resolve_tokenizer
 from ..shard import Shard


@@ -53,6 +54,7 @@ def _get_classes(config: dict):
   except ImportError:
     msg = f"Model type {model_type} not supported."
     logging.error(msg)
+    traceback.print_exc()
     raise ValueError(msg)

   return arch.Model, arch.ModelArgs
@@ -67,7 +69,6 @@ def load_config(model_path: Path) -> dict:
     raise
   return config

-
 def load_model_shard(
   model_path: Path,
   shard: Shard,
@@ -130,8 +131,17 @@ def load_model_shard(

   model_class, model_args_class = _get_classes(config=config)

+  class ShardedModel(model_class):
+    def __init__(self, args):
+      super().__init__(args)
+      self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
+
+    def __call__(self, x, *args, **kwargs):
+      y = super().__call__(x, *args, **kwargs)
+      return y
+
  model_args = model_args_class.from_dict(config)
-  model = model_class(model_args)
+  model = ShardedModel(model_args)

  if hasattr(model, "sanitize"):
    weights = model.sanitize(weights)
@@ -157,7 +167,6 @@ def load_model_shard(
  model.eval()
  return model

-
async def load_shard(
  model_path: str,
  shard: Shard,
@@ -167,9 +176,6 @@ async def load_shard(
  lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
  model = load_model_shard(model_path, shard, lazy, model_config)
-  if adapter_path is not None:
-    model = apply_lora_layers(model, adapter_path)
-    model.eval()

  # TODO: figure out a generic solution
  if model.model_type == "llava":
@@ -178,7 +184,7 @@ async def load_shard(
    processor.encode = processor.tokenizer.encode
    return model, processor
  else:
-    tokenizer = load_tokenizer(model_path, tokenizer_config)
+    tokenizer = await resolve_tokenizer(model_path)
    return model, tokenizer



+ 42 - 0
exo/inference/mlx/stateful_model.py

@@ -0,0 +1,42 @@
+from typing import Dict, Tuple
+from collections import OrderedDict
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import make_prompt_cache
+
+from ..shard import Shard
+
+class StatefulModel(nn.Module):
+  def __init__(self, model, max_kv_size: int = 1024, max_caches: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_kv_size = max_kv_size
+    self.max_caches = max_caches
+    self.caches = OrderedDict()
+  
+  def init_cache(self, request_id: str):
+    kv_heads = ([self.model.n_kv_heads]*len(self.model.layers) if isinstance(self.model.n_kv_heads, int) else self.model.n_kv_heads)
+    # if self.max_kv_size is not None:
+      # cache = [RotatingKVCache(self.model.head_dim, n, max_size=self.max_kv_size, keep=4) for n in kv_heads]
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    # else:
+      # cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+    cache = make_prompt_cache(self.model)
+
+    if len(self.caches) >= self.max_caches:
+      self.caches.popitem(last=False)
+
+    self.caches[request_id] = cache
+
+  def __call__(self, x, request_id: str):
+    if request_id not in self.caches:
+      self.init_cache(request_id)
+    else:
+      self.caches.move_to_end(request_id)
+
+    cache = self.caches[request_id]
+
+    y = self.model(x, cache=cache)
+    return y
+    

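StatefulModel replaces StatefulShardedModel: it only owns per-request KV caches (built with make_prompt_cache) and leaves sampling to the engine. A sketch of the per-request behaviour under stated assumptions; the snapshot path and model id are illustrative and the coroutine is not executed here:

import mlx.core as mx
from exo.inference.mlx.sharded_utils import load_shard
from exo.inference.mlx.stateful_model import StatefulModel
from exo.inference.shard import Shard

async def demo():
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
  model, tokenizer = await load_shard("/path/to/local/snapshot", shard=shard)  # illustrative path
  stateful = StatefulModel(model)
  tokens = mx.array(tokenizer.encode("hello"))[None]
  out_a = stateful(tokens, "request-a")  # lazily creates a prompt cache keyed by request id
  out_b = stateful(tokens, "request-b")  # independent cache; request-a's KV state is untouched
  return out_a, out_b
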
+ 4 - 4
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard

@@ -12,9 +12,9 @@ full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Ins
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)

-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
+full = StatefulModel(shard_full, full_model_shard)
+m1 = StatefulModel(shard1, model_shard1)
+m2 = StatefulModel(shard2, model_shard2)

 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))

+ 2 - 2
exo/inference/mlx/test_sharded_llava.py

@@ -5,9 +5,9 @@ from PIL import Image
 from io import BytesIO

 import mlx.core as mx
-from mlx_lm.models.base import KVCache
+from mlx_lm.models.cache import KVCache

-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.stateful_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard


+ 53 - 0
exo/inference/test_dummy_inference_engine.py

@@ -0,0 +1,53 @@
+import pytest
+import json
+import numpy as np
+from exo.inference.dummy_inference_engine import DummyInferenceEngine
+from exo.inference.shard import Shard
+
+
+class MockShardDownloader:
+  async def ensure_shard(self, shard):
+    pass
+
+
+@pytest.mark.asyncio
+async def test_dummy_inference_specific():
+  engine = DummyInferenceEngine(MockShardDownloader())
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+
+  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+
+  print(f"Inference result shape: {result.shape}")
+
+  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+
+
+@pytest.mark.asyncio
+async def test_dummy_inference_engine():
+  # Initialize the DummyInferenceEngine
+  engine = DummyInferenceEngine(MockShardDownloader())
+
+  # Create a test shard
+  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+
+  # Test infer_prompt
+  output = await engine.infer_prompt("test_id", shard, "Test prompt")
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  # Test infer_tensor
+  input_tensor = np.array([[1, 2, 3]])
+  output = await engine.infer_tensor("test_id", shard, input_tensor)
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+
+  print("All tests passed!")
+
+
+if __name__ == "__main__":
+  import asyncio
+  asyncio.run(test_dummy_inference_engine())
+  asyncio.run(test_dummy_inference_specific())

+ 20 - 28
exo/inference/test_inference_engine.py

@@ -9,46 +9,42 @@ import numpy as np


 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
-async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
+async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
   prompt = "In a single word only, what is the last name of the current president of the USA?"
   prompt = "In a single word only, what is the last name of the current president of the USA?"
-  resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-  next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
+  resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+  token_full = await inference_engine_1.sample(resp_full)
+  token_full = token_full.reshape(1, -1)
+  next_resp_full = await inference_engine_1.infer_tensor(
     "A",
     "A",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
-    input_data=resp_full,
-    inference_state=inference_state_full,
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
+    input_data=token_full,
  )

-  pp = 15
-  resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32), prompt=prompt)
-  resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
+  pp = n_layers // 2
+  resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+  resp2 = await inference_engine_2.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp1,
-    inference_state=inference_state_1,
  )
-  resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
+  tokens2 = await inference_engine_1.sample(resp2)
+  tokens2 = tokens2.reshape(1, -1)
+  resp3 = await inference_engine_1.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=32),
-    input_data=resp2,
-    inference_state=inference_state_2,
+    shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
+    input_data=tokens2,
  )
-  resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
+  resp4 = await inference_engine_2.infer_tensor(
     "B",
     "B",
-    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=31, n_layers=32),
+    shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
    input_data=resp3,
-    inference_state=inference_state_3,
  )

  assert np.array_equal(resp_full, resp2)
  assert np.array_equal(next_resp_full, resp4)


-asyncio.run(test_inference_engine(
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))

 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
@@ -56,9 +52,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
-    test_inference_engine(
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-    )
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
  )

+ 34 - 36
exo/inference/tinygrad/inference.py

@@ -1,17 +1,17 @@
 from pathlib import Path
 import json
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
 from concurrent.futures import ThreadPoolExecutor
+from .stateful_model import StatefulModel
 import asyncio

 Tensor.no_grad = True
@@ -22,7 +22,17 @@ TOP_P = 0.9
 ALPHA_F = 0.1
 ALPHA_P = 0.0
 MODEL_PARAMS = {
-  "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
+  "1B": {
+    "args": {
+      "dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "3B": {
+    "args": {
+      "dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
+      "rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
+    }, "files": 1
+  }, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
   "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
   "70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
 }
 }
 
 
@@ -30,8 +40,7 @@ MODEL_PARAMS = {
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   # build model
   linear = nn.Linear
   linear = nn.Linear
-  with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)

   # load weights
   if model_path.is_dir():
@@ -48,54 +57,43 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
     load_state_dict(model, weights, strict=False, consume=False)  # consume=True
   return model

-
 class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)

-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)

-  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+  async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
+  
+  async def decode(self, shard: Shard, tokens) -> str:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)

-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+    await self.ensure_shard(shard)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return

-    model_path = await self.shard_downloader.ensure_shard(shard)
+    model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)

     if self.shard != shard:
-      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+      loop = asyncio.get_running_loop()
+      parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
+      model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
 
 
      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
      self.tokenizer = await resolve_tokenizer(tokenizer_path)
      self.shard = shard
+      self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard) 
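
A minimal sketch of how the reworked engine methods above compose, assuming `engine` is an instance of this tinygrad engine, `shard` covers the final layers (so infer_tensor returns logits rather than a hidden state), and `generate_next_token`/`request_id` are illustrative names, not part of the diff:

async def generate_next_token(engine, shard, prompt: str, request_id: str) -> str:
  tokens = await engine.encode(shard, prompt)                                   # prompt -> 1-D array of token ids
  logits = await engine.infer_tensor(request_id, shard, tokens.reshape(1, -1))  # add a batch dim; the last shard yields [batch, seq, vocab] logits
  next_token = await engine.sample(logits, temp=0.7)                            # sample a token id from the final position
  return await engine.decode(shard, [int(next_token)])                          # token ids -> text
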

+ 61 - 36
exo/inference/tinygrad/models/llama.py

@@ -1,11 +1,23 @@
-from typing import Tuple, Union, Optional, Dict, Any
+from typing import Tuple, Union, Optional, Dict, Any, List
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
 from tinygrad.helpers import getenv
 from tinygrad.helpers import getenv
+from collections import OrderedDict
 
 
 
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half) -> Tensor:
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
   freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
+
+  if rope_scaling:
+    factor = rope_scaling.get('factor', 1.0)
+    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+    original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
+
+    freqs[:dim // 4] *= low_freq_factor
+    freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
+    freqs *= (original_max_pos_emb/end)**(1.0/factor)
+
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
   # TODO: move dtype outside this
   # TODO: move dtype outside this
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
   return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
@@ -36,7 +48,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   # NOTE: this is different from x.repeat((1, 1, n_rep, 1))
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
   return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
 
 
-
 class Attention:
 class Attention:
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
   def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
     self.n_heads = n_heads
     self.n_heads = n_heads
@@ -50,7 +61,7 @@ class Attention:
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
     self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
     if getenv("WQKV"):
     if getenv("WQKV"):
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
       xqkv = x @ self.wqkv.T
       xqkv = x @ self.wqkv.T
@@ -65,19 +76,16 @@ class Attention:
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
     bsz, seqlen, _, _ = xq.shape
     bsz, seqlen, _, _ = xq.shape
 
 
-    # create kv cache
-    if not hasattr(self, "cache_kv"):
-      self.cache_kv = Tensor.zeros(2, bsz, self.max_context, self.n_kv_heads, self.head_dim, dtype=x.dtype).contiguous().realize()
-      if isinstance(x.device, tuple):
-        # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
-        self.cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+    if cache is not None:
+      # update the cache
+      assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
+      cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
 
 
-    # update the cache
-    assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
-    self.cache_kv.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
-
-    keys = self.cache_kv[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
-    values = self.cache_kv[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+      keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
+      values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
+    else:
+      keys = xk
+      values = xv
 
 
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
     xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
@@ -103,13 +111,13 @@ class TransformerBlock:
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.attention_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
     self.ffn_norm = nn.RMSNorm(dim, norm_eps)
 
 
-  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor]):
-    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
+  def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
+    h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
     return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
 
 
 
 
 # standard openai sampling
 # standard openai sampling
-def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
+def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   assert logits.ndim == 1, "only works on 1d tensors"
   assert logits.ndim == 1, "only works on 1d tensors"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= p <= 1, "p must be between 0 and 1"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
   assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
@@ -176,42 +184,56 @@ class Transformer:
     rope_theta=10000,
     rope_theta=10000,
     max_context=1024,
     max_context=1024,
     jit=True,
     jit=True,
-    feed_forward=FeedForward
+    feed_forward=FeedForward,
+    rope_scaling: Optional[Dict[str, float]] = None,
+    tie_word_embeddings=False,
   ):
   ):
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.norm = nn.RMSNorm(dim, norm_eps)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.tok_embeddings = nn.Embedding(vocab_size, dim)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.output = nn.Linear(dim, vocab_size, bias=False)
+    if tie_word_embeddings:
+      self.output.weight = self.tok_embeddings.weight
     self.max_context = max_context
     self.max_context = max_context
-    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta).contiguous()
-    self.forward_jit = TinyJit(self.forward) if jit else None
+    self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
+    self.forward_jit = TinyJit(self.forward_base) if jit else None
     self.shard = shard
     self.shard = shard
 
 
-  def forward(self, x: Tensor, start_pos: Union[Variable, int], temperature: float, top_k: int, top_p: float, alpha_f: float, alpha_p: float):
+  def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
     seqlen = x.shape[1]
     seqlen = x.shape[1]
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
 
 
-    if self.shard.is_first_layer():
-      h = self.tok_embeddings(x)
-    else:
-      h = x
+    h = x
 
 
-    for i in range(self.shard.start_layer, self.shard.end_layer + 1):
+    if cache is None:
+      cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]  
+    for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
       layer = self.layers[i]
       layer = self.layers[i]
-      h = layer(h, start_pos, freqs_cis, mask)
+      h = layer(h, start_pos, freqs_cis, mask, cache=c)
 
 
     if self.shard.is_last_layer():
     if self.shard.is_last_layer():
-      logits = self.output(self.norm(h)).float()[:, -1, :]
-      return sample(logits.flatten(), temperature, top_k, top_p, alpha_f, alpha_p).realize()
+      logits = self.output(self.norm(h)).float().realize()
+      return logits
     else:
     else:
       return h
       return h
 
 
-  def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0, top_k: int = 0, top_p: float = 0.8, alpha_f: float = 0.0, alpha_p: float = 0.0):
+  def embed(self, inputs: Tensor):
+    if self.shard.is_first_layer():
+      h = self.tok_embeddings(inputs)
+    else:
+      h = inputs
+    return h
+
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
+    return self.forward_base(x, start_pos, cache=cache)
+
+  def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
     # TODO: better way to handle the first call v.s. the rest?
     # TODO: better way to handle the first call v.s. the rest?
-    if tokens.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(tokens, Variable("start_pos", 0, self.max_context).bind(start_pos), temperature, top_k, top_p, alpha_f, alpha_p)
-    return self.forward(tokens, start_pos, temperature, top_k, top_p, alpha_f, alpha_p)
+    h = self.embed(tokens)
+    return self.forward(h, start_pos, cache=cache)
 
 
 
 
 # *** helpers ***
 # *** helpers ***
@@ -245,7 +267,10 @@ def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_h
         v = permute(v, n_heads)
         v = permute(v, n_heads)
       elif "k_proj" in k:
       elif "k_proj" in k:
         v = permute(v, n_kv_heads)
         v = permute(v, n_kv_heads)
-    sd[keymap[k]] = v
+    if k in keymap:
+      sd[keymap[k]] = v
+    else:
+      sd[k] = v
  return sd
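
A rough sketch of driving the reworked Transformer above now that embedding and forward are split and the KV cache is caller-owned; `model` is assumed to be an already-built sharded Transformer, `prefill`/`token_ids` are illustrative names, and create_kv_cache comes from the new stateful_model module introduced below:

from tinygrad import Tensor
from exo.inference.tinygrad.stateful_model import create_kv_cache

def prefill(model, token_ids):
  x = Tensor([token_ids])                 # [batch=1, seq]
  h = model.embed(x)                      # tok_embeddings only applies on the first shard
  cache = [
    create_kv_cache(h, l.attention.max_context, l.attention.n_kv_heads, l.attention.head_dim)
    for l in model.layers[model.shard.start_layer:model.shard.end_layer + 1]
  ]
  # on the last shard this returns full [batch, seq, vocab] logits, otherwise the hidden state
  return model.forward(h, 0, cache=cache), cache
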
 
 
 
 

+ 42 - 0
exo/inference/tinygrad/stateful_model.py

@@ -0,0 +1,42 @@
+from tinygrad import Tensor, Variable
+from tinygrad.helpers import getenv  # used below for the SHARD_KVCACHE check
+from collections import OrderedDict
+from typing import List
+
+def create_kv_cache(x: Tensor, max_context: int, n_kv_heads: int, head_dim: int):
+  cache_kv = Tensor.zeros(2, x.shape[0], max_context, n_kv_heads, head_dim, dtype=x.dtype).contiguous().realize()
+  if isinstance(x.device, tuple):
+    # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
+    cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
+  return cache_kv.realize()
+
+class ModelState:
+  cache: List[Tensor]
+  start: int 
+  def __init__(self, cache: List[Tensor], start: int = 0):
+    self.cache = cache
+    self.start = start
+
+class StatefulModel:
+  def __init__(self, model, max_states: int = 2):
+    super().__init__()
+    self.model = model
+    self.max_states = max_states
+    self.states = OrderedDict()
+ 
+  def init_cache(self, x: Tensor, request_id: str):
+    cache = [create_kv_cache(x, self.model.layers[i].attention.max_context, self.model.layers[i].attention.n_kv_heads, self.model.layers[i].attention.head_dim) for i in range(self.model.shard.start_layer, self.model.shard.end_layer + 1)]
+    if len(self.states) >= self.max_states:
+      self.states.popitem(last=False)
+
+    self.states[request_id] = ModelState(cache)
+
+  def __call__(self, x: Tensor, request_id: str): 
+    h = self.model.embed(x)
+    if request_id not in self.states:
+      self.init_cache(h, request_id)
+    else:
+      self.states.move_to_end(request_id)
+    out = self.model.forward(h, self.states[request_id].start, cache=self.states[request_id].cache)
+    self.states[request_id].start += h.shape[1]
+    return out
+
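
A small usage sketch for StatefulModel; `transformer` is assumed to be a sharded Transformer as above and the request ids are illustrative:

from tinygrad import Tensor
from exo.inference.tinygrad.stateful_model import StatefulModel

def demo(transformer):
  stateful = StatefulModel(transformer, max_states=2)
  out_a = stateful(Tensor([[1, 2, 3]]), "request-a")  # allocates a KV cache keyed by request id
  out_b = stateful(Tensor([[4, 5, 6]]), "request-b")  # a second in-flight request gets its own cache
  out_c = stateful(Tensor([[7, 8, 9]]), "request-c")  # exceeds max_states, so request-a's cache is evicted (LRU)
  return out_a, out_b, out_c
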

+ 15 - 0
exo/inference/tokenizers.py

@@ -7,7 +7,21 @@ from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
 
 
+
+class DummyTokenizer:
+  def __init__(self):
+    self.eos_token_id = 0
+
+  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
+    return [1, 2, 3]
+
+  def decode(self, tokens):
+    return "dummy"
+
+
 async def resolve_tokenizer(model_id: str):
 async def resolve_tokenizer(model_id: str):
+  if model_id == "dummy":
+    return DummyTokenizer()
   local_path = await get_local_snapshot_dir(model_id)
   local_path = await get_local_snapshot_dir(model_id)
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   try:
   try:
@@ -19,6 +33,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
   return await _resolve_tokenizer(model_id)
 
 
+
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")

+ 98 - 47
exo/main.py

@@ -1,11 +1,17 @@
 import argparse
 import argparse
 import asyncio
 import asyncio
+import atexit
 import signal
 import signal
 import json
 import json
+import logging
+import platform
+import os
+import sys
 import time
 import time
 import traceback
 import traceback
 import uuid
 import uuid
-import sys
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
 from exo.orchestration.standard_node import StandardNode
 from exo.orchestration.standard_node import StandardNode
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.networking.udp.udp_discovery import UDPDiscovery
 from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -13,49 +19,57 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent
+from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.download.hf.hf_shard_download import HFShardDownloader
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration.node import Node
 from exo.orchestration.node import Node
-from exo.models import model_base_shards
+from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
 from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 
 
 # parse args
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
+parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
-parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale"], default="udp", help="Discovery module to use")
+parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
 parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
+parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
-parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
+parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
 parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
 args = parser.parse_args()
+print(f"Selected inference engine: {args.inference_engine}")
 
 
 print_yellow_exo()
 print_yellow_exo()
 
 
 system_info = get_system_info()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 print(f"Detected system: {system_info}")
 
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads)
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
+                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
+print(f"Inference engine name after selection: {inference_engine_name}")
+
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
 
 
@@ -75,9 +89,27 @@ if DEBUG >= 0:
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
     print(f" - {terminal_link(chatgpt_api_endpoint)}")
 
 
 if args.discovery_module == "udp":
 if args.discovery_module == "udp":
-  discovery = UDPDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout)
+  discovery = UDPDiscovery(
+    args.node_id,
+    args.node_port,
+    args.listen_port,
+    args.broadcast_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout
+  )
 elif args.discovery_module == "tailscale":
 elif args.discovery_module == "tailscale":
-  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
+  discovery = TailscaleDiscovery(
+    args.node_id,
+    args.node_port,
+    lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
+    discovery_timeout=args.discovery_timeout,
+    tailscale_api_key=args.tailscale_api_key,
+    tailnet=args.tailnet_name
+  )
+elif args.discovery_module == "manual":
+  if not args.discovery_config_path:
+    raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
+  discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
 node = StandardNode(
 node = StandardNode(
   args.node_id,
   args.node_id,
@@ -86,32 +118,35 @@ node = StandardNode(
   discovery,
   discovery,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   max_generate_tokens=args.max_generate_tokens,
-  topology_viz=topology_viz
+  topology_viz=topology_viz,
+  shard_downloader=shard_downloader
 )
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
 node.server = server
-node.download_progress = {}  # Initialize download progress tracking
-node.node_download_progress = {}  # For tracking per-node download progress
 api = ChatGPTAPI(
 api = ChatGPTAPI(
   node,
   node,
   inference_engine.__class__.__name__,
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
   response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
 )
 )
 node.on_token.register("update_topology_viz").on_next(
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
 )
 )
+
 def preemptively_start_download(request_id: str, opaque_status: str):
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
   try:
     status = json.loads(opaque_status)
     status = json.loads(opaque_status)
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
     if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
       if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-      asyncio.create_task(shard_downloader.ensure_shard(current_shard))
+      asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
   except Exception as e:
   except Exception as e:
     if DEBUG >= 2:
     if DEBUG >= 2:
       print(f"Failed to preemptively start download: {e}")
       print(f"Failed to preemptively start download: {e}")
       traceback.print_exc()
       traceback.print_exc()
+
+
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
 
 
 if args.prometheus_client_port:
 if args.prometheus_client_port:
@@ -120,41 +155,24 @@ if args.prometheus_client_port:
 
 
 last_broadcast_time = 0
 last_broadcast_time = 0
 
 
-def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        node.download_progress[event.repo_id] = event.to_dict()
-        node.node_download_progress[node.id] = event.to_dict()
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
-            "type": "download_progress",
-            "node_id": node.id,
-            "progress": event.to_dict()
-        })))
-
-shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
+def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
 
 
-async def shutdown(signal, loop):
-  """Gracefully shutdown the server and close the asyncio loop."""
-  print(f"Received exit signal {signal.name}...")
-  print("Thank you for using exo.")
-  print_yellow_exo()
-  server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
-  [task.cancel() for task in server_tasks]
-  print(f"Cancelling {len(server_tasks)} outstanding tasks")
-  await asyncio.gather(*server_tasks, return_exceptions=True)
-  await server.stop()
-  loop.stop()
 
 
+shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
 
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
-  shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
+  inference_class = inference_engine.__class__.__name__
+  shard = build_base_shard(model_name, inference_class)
   if not shard:
   if not shard:
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
     return
     return
-  tokenizer = await resolve_tokenizer(shard.model_id)
+  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
   request_id = str(uuid.uuid4())
   request_id = str(uuid.uuid4())
   callback_id = f"cli-wait-response-{request_id}"
   callback_id = f"cli-wait-response-{request_id}"
   callback = node.on_token.register(callback_id)
   callback = node.on_token.register(callback_id)
@@ -164,7 +182,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
 
 
   try:
   try:
     print(f"Processing prompt: {prompt}")
     print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, None, request_id=request_id)
+    await node.process_prompt(shard, prompt, request_id=request_id)
 
 
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
     _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
 
 
@@ -176,16 +194,48 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
   finally:
   finally:
     node.on_token.deregister(callback_id)
     node.on_token.deregister(callback_id)
 
 
+def clean_path(path):
+    """Clean and resolve path"""
+    if path.startswith("Optional("):
+        path = path.strip('Optional("').rstrip('")')
+    return os.path.expanduser(path)
 
 
 async def main():
 async def main():
   loop = asyncio.get_running_loop()
   loop = asyncio.get_running_loop()
 
 
+  # Check HuggingFace directory permissions
+  hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
+  if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
+  print(f"{has_read=}, {has_write=}")
+  if not has_read or not has_write:
+    print(f"""
+          WARNING: Limited permissions for model storage directory: {hf_home}.
+          This may prevent model downloads from working correctly.
+          {"❌ No read access" if not has_read else ""}
+          {"❌ No write access" if not has_write else ""}
+          """)
+    
+  if not args.models_seed_dir is None:
+    try:
+      models_seed_dir = clean_path(args.models_seed_dir)
+      await move_models_to_hf(models_seed_dir)
+    except Exception as e:
+      print(f"Error moving models to .cache/huggingface: {e}")
+
+  def restore_cursor():
+    if platform.system() != "Windows":
+        os.system("tput cnorm")  # Show cursor
+
+  # Restore the cursor when the program exits
+  atexit.register(restore_cursor)
+
   # Use a more direct approach to handle signals
   # Use a more direct approach to handle signals
   def handle_exit():
   def handle_exit():
-    asyncio.ensure_future(shutdown(signal.SIGTERM, loop))
+    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
 
 
-  for s in [signal.SIGINT, signal.SIGTERM]:
-    loop.add_signal_handler(s, handle_exit)
+  if platform.system() != "Windows":
+    for s in [signal.SIGINT, signal.SIGTERM]:
+      loop.add_signal_handler(s, handle_exit)
 
 
   await node.start(wait_for_peers=args.wait_for_peers)
   await node.start(wait_for_peers=args.wait_for_peers)
 
 
@@ -208,8 +258,9 @@ def run():
   except KeyboardInterrupt:
   except KeyboardInterrupt:
     print("Received keyboard interrupt. Shutting down...")
     print("Received keyboard interrupt. Shutting down...")
   finally:
   finally:
-    loop.run_until_complete(shutdown(signal.SIGTERM, loop))
+    loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
     loop.close()
     loop.close()
 
 
+
 if __name__ == "__main__":
 if __name__ == "__main__":
   run()
   run()
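
For reference, a hedged sketch of the opaque status payload that preemptively_start_download reacts to; the shard values here are illustrative:

import json

status = json.dumps({
  "type": "node_status",
  "status": "start_process_prompt",
  "shard": {"model_id": "llama-3.2-1b", "start_layer": 0, "end_layer": 15, "n_layers": 16},
})
# a node broadcasting this causes peers to call
# shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__)
# before the prompt itself arrives
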

+ 126 - 40
exo/models.py

@@ -1,62 +1,148 @@
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
+from typing import Optional, List
 
 
-model_base_shards = {
+model_cards = {
   ### llama
   ### llama
   "llama-3.2-1b": {
   "llama-3.2-1b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),
+    "layers": 16,
+    "repo": {
+      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
+      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+    },
   },
   },
   "llama-3.2-3b": {
   "llama-3.2-3b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
+    "layers": 28,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+    },
   },
   },
   "llama-3.1-8b": {
   "llama-3.1-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+    },
   },
   },
   "llama-3.1-70b": {
   "llama-3.1-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   },
   "llama-3.1-70b-bf16": {
   "llama-3.1-70b-bf16": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+       "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+    },
   },
   },
-  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {
   "llama-3-8b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+    "layers": 32,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+    },
   },
   },
   "llama-3-70b": {
   "llama-3-70b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+    "layers": 80,
+    "repo": {
+       "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+       "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+    },
   },
   },
+  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
+  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
   ### mistral
   ### mistral
-  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
-  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
+  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
   ### deepseek
   ### deepseek
-  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
-  "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
+  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
   ### llava
   ### llava
-  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
   ### qwen
   ### qwen
-  "qwen-2.5-coder-1.5b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-coder-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-math-7b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),
-  },
-  "qwen-2.5-14b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),
-  },
-  "qwen-2.5-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
-  "qwen-2.5-math-72b": {
-    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-  },
+  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
+  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
+  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
+  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
+  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
+  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
+  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
+  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
+  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
+  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
+  ### nemotron
+  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
+  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
+  # gemma
+  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
+  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
+  # dummy
+  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
+}
+
+pretty_name = {
+  "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.1-8b": "Llama 3.1 8B",
+  "llama-3.1-70b": "Llama 3.1 70B",
+  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+  "llama-3.1-405b": "Llama 3.1 405B",
+  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+  "gemma2-9b": "Gemma2 9B",
+  "gemma2-27b": "Gemma2 27B",
+  "nemotron-70b": "Nemotron 70B",
+  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+  "mistral-nemo": "Mistral Nemo",
+  "mistral-large": "Mistral Large",
+  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+  "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+  "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-72b": "Qwen 2.5 72B",
+  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "llama-3-8b": "Llama 3 8B",
+  "llama-3-70b": "Llama 3 70B",
 }
 }
+
+def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
+  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
+
+def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
+  repo = get_repo(model_id, inference_engine_classname)
+  n_layers = model_cards.get(model_id, {}).get("layers", 0)
+  if repo is None or n_layers < 1:
+    return None
+  return Shard(model_id, 0, 0, n_layers)
+
+def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
+  if not supported_inference_engine_lists:
+    return list(model_cards.keys())
+
+  from exo.inference.inference_engine import inference_engine_classes
+  supported_inference_engine_lists = [
+    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+    for engine_list in supported_inference_engine_lists
+  ]
+
+  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
+    return any(engine in model_info.get("repo", {}) for engine in engine_list)
+
+  def supports_all_engine_lists(model_info: dict) -> bool:
+    return all(has_any_engine(model_info, engine_list)
+              for engine_list in supported_inference_engine_lists)
+
+  return [
+    model_id for model_id, model_info in model_cards.items()
+    if supports_all_engine_lists(model_info)
+  ]
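
A short sketch of the new lookup helpers; the values follow the model_cards entries above:

from exo.models import build_base_shard, get_repo, get_supported_models

shard = build_base_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
print(shard.model_id, shard.n_layers)                                   # llama-3.2-1b 16
print(get_repo("llama-3.2-1b", "TinygradDynamicShardInferenceEngine"))  # unsloth/Llama-3.2-1B-Instruct

# models that a two-node cluster can serve when one node only has MLX and the other only has tinygrad
print(get_supported_models([["MLXDynamicShardInferenceEngine"], ["TinygradDynamicShardInferenceEngine"]]))
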

+ 12 - 9
exo/networking/grpc/grpc_peer_handle.py

@@ -9,7 +9,7 @@ from . import node_service_pb2_grpc
 from ..peer_handle import PeerHandle
 from ..peer_handle import PeerHandle
 from exo.inference.shard import Shard
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
 from exo.topology.topology import Topology
-from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 from exo.helpers import DEBUG
 
 
 
 
@@ -32,7 +32,11 @@ class GRPCPeerHandle(PeerHandle):
 
 
   async def connect(self):
   async def connect(self):
     if self.channel is None:
     if self.channel is None:
-      self.channel = grpc.aio.insecure_channel(self.address, options=[("grpc.max_metadata_size", 32*1024*1024)])
+      self.channel = grpc.aio.insecure_channel(self.address, options=[
+        ("grpc.max_metadata_size", 32*1024*1024),
+        ('grpc.max_receive_message_length', 32*1024*1024),
+        ('grpc.max_send_message_length', 32*1024*1024)
+      ])
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
       self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
     await self.channel.channel_ready()
     await self.channel.channel_ready()
 
 
@@ -56,17 +60,16 @@ class GRPCPeerHandle(PeerHandle):
       return response.is_healthy
       return response.is_healthy
     except asyncio.TimeoutError:
     except asyncio.TimeoutError:
       return False
       return False
-    except:
+    except Exception:
       if DEBUG >= 4:
       if DEBUG >= 4:
         print(f"Health check failed for {self._id}@{self.address}.")
         print(f"Health check failed for {self._id}@{self.address}.")
         import traceback
         import traceback
         traceback.print_exc()
         traceback.print_exc()
       return False
       return False
 
 
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.PromptRequest(
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       prompt=prompt,
-      image_str=image_str,
       shard=node_service_pb2.Shard(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         model_id=shard.model_id,
         start_layer=shard.start_layer,
         start_layer=shard.start_layer,
@@ -74,7 +77,6 @@ class GRPCPeerHandle(PeerHandle):
         n_layers=shard.n_layers,
         n_layers=shard.n_layers,
       ),
       ),
       request_id=request_id,
       request_id=request_id,
-      inference_state=inference_state,
     )
     )
     response = await self.stub.SendPrompt(request)
     response = await self.stub.SendPrompt(request)
 
 
@@ -83,7 +85,7 @@ class GRPCPeerHandle(PeerHandle):
 
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         model_id=shard.model_id,
@@ -93,7 +95,6 @@ class GRPCPeerHandle(PeerHandle):
       ),
       ),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
       request_id=request_id,
-      inference_state=inference_state,
     )
     )
     response = await self.stub.SendTensor(request)
     response = await self.stub.SendTensor(request)
 
 
@@ -117,7 +118,9 @@ class GRPCPeerHandle(PeerHandle):
     response = await self.stub.CollectTopology(request)
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
     for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+      )
       topology.update_node(node_id, device_capabilities)
       topology.update_node(node_id, device_capabilities)
    for node_id, peers in response.peer_graph.items():
      for peer_id in peers.peer_ids:
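
A minimal sketch of forwarding an activation with the slimmed-down peer API; `peer` and `shard` are assumed to exist already and the helper name is illustrative:

import numpy as np

async def forward_to_next_peer(peer, shard, hidden: np.ndarray, request_id: str):
  # send_tensor no longer carries inference_state; per-request KV state now lives
  # on the node that owns the shard
  return await peer.send_tensor(shard, hidden, request_id=request_id)
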

+ 3 - 5
exo/networking/grpc/grpc_server.py

@@ -49,10 +49,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       n_layers=request.shard.n_layers,
       n_layers=request.shard.n_layers,
     )
     )
     prompt = request.prompt
     prompt = request.prompt
-    image_str = request.image_str
     request_id = request.request_id
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, image_str, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {image_str=} {request_id=} result: {result}")
+    result = await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
     tensor_data = result.tobytes() if result is not None else None
     tensor_data = result.tobytes() if result is not None else None
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
     return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
 
 
@@ -65,9 +64,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
     request_id = request.request_id
-    inference_state = request.inference_state
 
 
-    result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
+    result = await self.node.process_tensor(shard, tensor, request_id)
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
     if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
    tensor_data = result.tobytes() if result is not None else None
    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
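
For illustration, how a numpy result is packed into the Tensor reply used above; the shape is illustrative:

import numpy as np
from exo.networking.grpc import node_service_pb2

result = np.zeros((1, 1, 128256), dtype=np.float32)  # e.g. logits from the last shard
reply = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
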

+ 2 - 5
exo/networking/grpc/node_service.proto

@@ -22,16 +22,13 @@ message Shard {
 message PromptRequest {
 message PromptRequest {
   Shard shard = 1;
   Shard shard = 1;
   string prompt = 2;
   string prompt = 2;
-  optional string image_str = 3;
-  optional string request_id = 4;
-  optional string inference_state = 5;
+  optional string request_id = 3;
 }
 }
 
 
 message TensorRequest {
 message TensorRequest {
   Shard shard = 1;
   Shard shard = 1;
   Tensor tensor = 2;
   Tensor tensor = 2;
   optional string request_id = 3;
   optional string request_id = 3;
-  optional string inference_state = 4;
 }
 }
 
 
 message GetInferenceResultRequest {
 message GetInferenceResultRequest {
@@ -93,4 +90,4 @@ message HealthCheckResponse {
   bool is_healthy = 1;
   bool is_healthy = 1;
 }
 }
 
 
-message Empty {}
+message Empty {}
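
A sketch of building the trimmed request messages from Python, using only the fields that remain after this change (ids and values are illustrative):

from exo.networking.grpc import node_service_pb2

shard = node_service_pb2.Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
prompt_req = node_service_pb2.PromptRequest(shard=shard, prompt="Who are you?", request_id="req-1")
# image_str and inference_state no longer exist on PromptRequest or TensorRequest
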

The diff for this file was suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 0 - 0
exo/networking/manual/__init__.py


+ 71 - 0
exo/networking/manual/manual_discovery.py

@@ -0,0 +1,71 @@
+import asyncio
+from exo.networking.discovery import Discovery
+from typing import Dict, List, Callable
+
+from exo.topology.device_capabilities import DeviceCapabilities
+from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
+from exo.helpers import DEBUG_DISCOVERY
+from exo.networking.peer_handle import PeerHandle
+
+
+class ManualDiscovery(Discovery):
+  def __init__(
+    self,
+    network_config_path: str,
+    node_id: str,
+    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
+  ):
+    self.topology = NetworkTopology.from_path(network_config_path)
+    self.create_peer_handle = create_peer_handle
+
+    if node_id not in self.topology.peers:
+      raise ValueError(
+        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {list(self.topology.peers.keys())}"
+      )
+
+    self.listen_task = None
+
+    self.known_peers: Dict[str, PeerHandle] = {}
+    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
+    self.peers_in_network.pop(node_id)
+
+  async def start(self) -> None:
+    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
+
+  async def stop(self) -> None:
+    if self.listen_task:
+      self.listen_task.cancel()
+
+  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
+    if wait_for_peers > 0:
+      while len(self.known_peers) < wait_for_peers:
+        if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
+        await asyncio.sleep(0.1)
+    if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
+    return list(self.known_peers.values())
+
+  async def task_find_peers_from_config(self):
+    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
+    while True:
+      for peer_id, peer_config in self.peers_in_network.items():
+        try:
+          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
+          peer = self.known_peers.get(peer_id)
+          if not peer:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
+          is_healthy = await peer.health_check()
+          if is_healthy:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
+            self.known_peers[peer_id] = peer
+          else:
+            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
+            try:
+              del self.known_peers[peer_id]
+            except KeyError:
+              pass
+        except Exception as e:
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
+      await asyncio.sleep(1.0)
+
+      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

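For reference, a minimal sketch (not part of the change) of how the new ManualDiscovery might be driven on its own; the config path and node id are illustrative, and GRPCPeerHandle is the peer-handle factory used by the tests further down:

import asyncio
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle


async def main():
  # "node1" and the config path are placeholders; they must match a key in the JSON config.
  discovery = ManualDiscovery(
    "./exo/networking/manual/test_data/test_config.json",
    "node1",
    create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities),
  )
  await discovery.start()  # spawns task_find_peers_from_config in the background
  peers = await discovery.discover_peers(wait_for_peers=1)  # polls until at least one healthy peer is known
  print([peer.id() for peer in peers])
  await discovery.stop()


asyncio.run(main())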
+ 31 - 0
exo/networking/manual/network_topology_config.py

@@ -0,0 +1,31 @@
+from typing import Dict
+from pydantic import BaseModel, ValidationError
+
+from exo.topology.device_capabilities import DeviceCapabilities
+
+
+class PeerConfig(BaseModel):
+  address: str
+  port: int
+  device_capabilities: DeviceCapabilities
+
+
+class NetworkTopology(BaseModel):
+  """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""
+
+  peers: Dict[str, PeerConfig]
+  """
+  node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
+  """
+  @classmethod
+  def from_path(cls, path: str) -> "NetworkTopology":
+    try:
+      with open(path, "r") as f:
+        config_data = f.read()
+    except FileNotFoundError as e:
+      raise FileNotFoundError(f"Config file not found at {path}") from e
+
+    try:
+      return cls.model_validate_json(config_data)
+    except ValidationError as e:
+      raise ValueError(f"Error validating network topology config from {path}: {e}") from e

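A small sketch (the path is hypothetical) of how the loader above surfaces problems and exposes the parsed peers:

from exo.networking.manual.network_topology_config import NetworkTopology

try:
  topology = NetworkTopology.from_path("./my_topology.json")  # hypothetical path
except FileNotFoundError as e:
  print(f"missing config: {e}")
except ValueError as e:
  print(f"invalid config: {e}")  # includes the pydantic validation details
else:
  for node_id, peer in topology.peers.items():
    print(node_id, f"{peer.address}:{peer.port}", peer.device_capabilities.model)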
+ 17 - 0
exo/networking/manual/test_data/invalid_config.json

@@ -0,0 +1,17 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 0 - 0
exo/networking/manual/test_data/invalid_json.json


+ 32 - 0
exo/networking/manual/test_data/test_config.json

@@ -0,0 +1,32 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "port": 50051,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    },
+    "node2": {
+      "address": "localhost",
+      "port": 50052,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 18 - 0
exo/networking/manual/test_data/test_config_single_node.json

@@ -0,0 +1,18 @@
+{
+  "peers": {
+    "node1": {
+      "address": "localhost",
+      "port": 50051,
+      "device_capabilities": {
+        "model": "Unknown Model",
+        "chip": "Unknown Chip",
+        "memory": 0,
+        "flops": {
+          "fp32": 0,
+          "fp16": 0,
+          "int8": 0
+        }
+      }
+    }
+  }
+}

+ 103 - 0
exo/networking/manual/test_manual_discovery.py

@@ -0,0 +1,103 @@
+import asyncio
+import unittest
+from unittest import mock
+from exo.networking.manual.manual_discovery import ManualDiscovery
+from exo.networking.manual.network_topology_config import NetworkTopology
+from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
+from exo.networking.grpc.grpc_server import GRPCServer
+from exo.orchestration.node import Node
+
+root_path = "./exo/networking/manual/test_data/test_config.json"
+
+
+class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    await self.discovery1.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=0)
+    assert len(peers1) == 0
+
+    self.peer1.connect.assert_not_called()
+
+
+class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    self.peer1 = mock.AsyncMock()
+    self.peer2 = mock.AsyncMock()
+    self.peer1.connect = mock.AsyncMock()
+    self.peer2.connect = mock.AsyncMock()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer1)
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: self.peer2)
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+
+  async def test_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # connect has to be explicitly called after discovery
+    self.peer1.connect.assert_not_called()
+    self.peer2.connect.assert_not_called()
+
+
+class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
+  async def asyncSetUp(self):
+    config = NetworkTopology.from_path(root_path)
+
+    self.node1 = mock.AsyncMock(spec=Node)
+    self.node2 = mock.AsyncMock(spec=Node)
+    self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port)
+    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
+    await self.server1.start()
+    await self.server2.start()
+    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities))
+    await self.discovery1.start()
+    await self.discovery2.start()
+
+  async def asyncTearDown(self):
+    await self.discovery1.stop()
+    await self.discovery2.stop()
+    await self.server1.stop()
+    await self.server2.stop()
+
+  async def test_grpc_discovery(self):
+    peers1 = await self.discovery1.discover_peers(wait_for_peers=1)
+    assert len(peers1) == 1
+    peers2 = await self.discovery2.discover_peers(wait_for_peers=1)
+    assert len(peers2) == 1
+
+    # Connect
+    await peers1[0].connect()
+    await peers2[0].connect()
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertTrue(await peers2[0].is_connected())
+
+    # Kill server1
+    await self.server1.stop()
+
+    self.assertTrue(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+    # Kill server2
+    await self.server2.stop()
+
+    self.assertFalse(await peers1[0].is_connected())
+    self.assertFalse(await peers2[0].is_connected())
+
+
+if __name__ == "__main__":
+  asyncio.run(unittest.main())

+ 49 - 0
exo/networking/manual/test_network_topology_config.py

@@ -0,0 +1,49 @@
+import unittest
+
+from exo.networking.manual.network_topology_config import NetworkTopology
+
+root_path = "./exo/networking/manual/test_data/"
+
+
+class TestNetworkTopologyConfig(unittest.TestCase):
+  def test_from_path_invalid_path(self):
+    with self.assertRaises(FileNotFoundError) as e:
+      NetworkTopology.from_path("invalid_path")
+    self.assertEqual(str(e.exception), "Config file not found at invalid_path")
+
+  def test_from_path_invalid_json(self):
+    with self.assertRaises(ValueError) as e:
+      NetworkTopology.from_path(root_path + "invalid_json.json")
+    self.assertIn("Error validating network topology config from", str(e.exception))
+    self.assertIn("1 validation error for NetworkTopology\n  Invalid JSON: EOF while parsing a value at line 1 column 0", str(e.exception))
+
+  def test_from_path_invalid_config(self):
+    with self.assertRaises(ValueError) as e:
+      NetworkTopology.from_path(root_path + "invalid_config.json")
+    self.assertIn("Error validating network topology config from", str(e.exception))
+    self.assertIn("port\n  Field required", str(e.exception))
+
+  def test_from_path_valid(self):
+    config = NetworkTopology.from_path(root_path + "test_config.json")
+
+    self.assertEqual(config.peers["node1"].port, 50051)
+    self.assertEqual(config.peers["node1"].device_capabilities.model, "Unknown Model")
+    self.assertEqual(config.peers["node1"].address, "localhost")
+    self.assertEqual(config.peers["node1"].device_capabilities.chip, "Unknown Chip")
+    self.assertEqual(config.peers["node1"].device_capabilities.memory, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp32, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.fp16, 0)
+    self.assertEqual(config.peers["node1"].device_capabilities.flops.int8, 0)
+
+    self.assertEqual(config.peers["node2"].port, 50052)
+    self.assertEqual(config.peers["node2"].device_capabilities.model, "Unknown Model")
+    self.assertEqual(config.peers["node2"].address, "localhost")
+    self.assertEqual(config.peers["node2"].device_capabilities.chip, "Unknown Chip")
+    self.assertEqual(config.peers["node2"].device_capabilities.memory, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp32, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.fp16, 0)
+    self.assertEqual(config.peers["node2"].device_capabilities.flops.int8, 0)
+
+
+if __name__ == "__main__":
+  unittest.main()

+ 3 - 2
exo/networking/peer_handle.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology

+
 class PeerHandle(ABC):
   @abstractmethod
   def id(self) -> str:
@@ -35,11 +36,11 @@ class PeerHandle(ABC):
     pass

   @abstractmethod
-  async def send_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
     pass

   @abstractmethod
-  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]:
     pass

   @abstractmethod

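With image_str and inference_state removed, callers now pass only a shard, a payload and an optional request id. A hedged sketch of the new call shape (the peer object and the Shard field values are assumptions for illustration, not part of this change):

import numpy as np
from exo.inference.shard import Shard


async def run_first_shard(peer, prompt: str, request_id: str):
  # `peer` is any concrete PeerHandle (e.g. GRPCPeerHandle); the model id and layer bounds are illustrative.
  shard = Shard(model_id="llama-3.1-8b", start_layer=0, end_layer=15, n_layers=32)
  result = await peer.send_prompt(shard, prompt, request_id=request_id)
  return np.array(result) if result is not None else None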
+ 42 - 18
exo/networking/tailscale/tailscale_discovery.py

@@ -2,12 +2,12 @@ import asyncio
 import time
 import traceback
 from typing import List, Dict, Callable, Tuple
-from tailscale import Tailscale, Device
 from exo.networking.discovery import Discovery
 from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
-from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes
+from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
+

 class TailscaleDiscovery(Discovery):
   def __init__(
@@ -32,7 +32,8 @@ class TailscaleDiscovery(Discovery):
     self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
     self.discovery_task = None
     self.cleanup_task = None
-    self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
+    self.tailscale_api_key = tailscale_api_key
+    self.tailnet = tailnet
     self._device_id = None
     self.update_task = None
 
@@ -61,27 +62,24 @@ class TailscaleDiscovery(Discovery):
     return self._device_id

   async def update_device_posture_attributes(self):
-    await update_device_attributes(await self.get_device_id(), self.tailscale.api_key, self.node_id, self.node_port, self.device_capabilities)
+    await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)

   async def task_discover_peers(self):
     while True:
       try:
-        devices: dict[str, Device] = await self.tailscale.devices()
+        devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         current_time = time.time()

-        active_devices = {
-          name: device for name, device in devices.items()
-          if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
-        }
+        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}

         if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
         if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
-        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time  - device.last_seen.timestamp()) for device in devices.values()])
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values()])

         for device in active_devices.values():
           if device.name == self.node_id: continue
           peer_host = device.addresses[0]
-          peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
+          peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
           if not peer_id:
             if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
             continue
@@ -133,16 +131,42 @@ class TailscaleDiscovery(Discovery):
     while True:
       try:
         current_time = time.time()
-        peers_to_remove = [
-          peer_handle.id() for peer_handle, connected_at, last_seen in self.known_peers.values()
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or current_time - last_seen > self.discovery_timeout or not await peer_handle.health_check()
-        ]
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", {peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, {connected_at=}, {last_seen=}, health_check={await peer_handle.health_check()}" for peer_handle, connected_at, last_seen in self.known_peers.values()})
+        peers_to_remove = []
+
+        peer_ids = list(self.known_peers.keys())
+        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
+
+        for peer_id, should_remove in zip(peer_ids, results):
+          if should_remove: peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            }
+          )
+
         for peer_id in peers_to_remove:
-          if peer_id in self.known_peers: del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity.")
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())
       finally:
         await asyncio.sleep(self.discovery_interval)
+
+  async def check_peer(self, peer_id: str, current_time: float) -> bool:
+    peer_handle, connected_at, last_seen = self.known_peers.get(peer_id, (None, None, None))
+    if peer_handle is None: return False
+
+    try:
+      is_connected = await peer_handle.is_connected()
+      health_ok = await peer_handle.health_check()
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
+      return True
+
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
+    return should_remove

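The cleanup loop above now fans the per-peer checks out through asyncio.gather instead of awaiting each peer inline. A standalone sketch of the same pattern (check_one is a simplified stand-in for check_peer; the timeout conditions are omitted):

import asyncio


async def check_one(peer) -> bool:
  # True means the peer should be dropped; any error during the check also counts as a failure.
  try:
    return not (await peer.is_connected() and await peer.health_check())
  except Exception:
    return True


async def prune(known_peers: dict) -> None:
  peer_ids = list(known_peers.keys())
  results = await asyncio.gather(*[check_one(known_peers[pid]) for pid in peer_ids], return_exceptions=True)
  for pid, should_remove in zip(peer_ids, results):
    # return_exceptions=True can hand back an exception object; treat that as a removal too.
    if should_remove is True or isinstance(should_remove, Exception):
      known_peers.pop(pid, None)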
+ 49 - 20
exo/networking/tailscale/tailscale_helpers.py

@@ -2,17 +2,33 @@ import json
 import asyncio
 import aiohttp
 import re
-from typing import Dict, Any, Tuple
+from typing import Dict, Any, Tuple, List, Optional
 from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
+from datetime import datetime, timezone
+
+
+class Device:
+  def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
+    self.device_id = device_id
+    self.name = name
+    self.addresses = addresses
+    self.last_seen = last_seen
+
+  @classmethod
+  def from_dict(cls, data: Dict[str, Any]) -> 'Device':
+    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
+
+  @staticmethod
+  def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
+    if not date_string:
+      return None
+    return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
+
 
 
 async def get_device_id() -> str:
   try:
-    process = await asyncio.create_subprocess_exec(
-      'tailscale', 'status', '--json',
-      stdout=asyncio.subprocess.PIPE,
-      stderr=asyncio.subprocess.PIPE
-    )
+    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
     stdout, stderr = await process.communicate()
     if process.returncode != 0:
       raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
@@ -22,22 +38,16 @@ async def get_device_id() -> str:
   except Exception as e:
     raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")

+
 async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
   async with aiohttp.ClientSession() as session:
     base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}',
-      'Content-Type': 'application/json'
-    }
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}

     attributes = {
-      "custom:exo_node_id": node_id.replace('-', '_'),
-      "custom:exo_node_port": node_port,
-      "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
-      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model),
-      "custom:exo_device_capability_memory": str(device_capabilities.memory),
-      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
-      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
       "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
     }
 
@@ -50,12 +60,11 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
         else:
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")

+
 async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}'
-    }
+    headers = {'Authorization': f'Bearer {api_key}'}
     async with session.get(url, headers=headers) as response:
       if response.status == 200:
         data = await response.json()
@@ -77,6 +86,7 @@ async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int,
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
         return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))

+
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
   result = {}
   prefix = "custom:exo_"
@@ -89,8 +99,27 @@ def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
         result[attr_name] = float(value)
   return result

+
 def sanitize_attribute(value: str) -> str:
   # Replace invalid characters with underscores
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   # Truncate to 50 characters
   return sanitized_value[:50]
+
+
+async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
+  async with aiohttp.ClientSession() as session:
+    url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
+    headers = {"Authorization": f"Bearer {api_key}"}
+
+    async with session.get(url, headers=headers) as response:
+      response.raise_for_status()
+      data = await response.json()
+
+      devices = {}
+      for device_data in data.get("devices", []):
+        print("Device data: ", device_data)
+        device = Device.from_dict(device_data)
+        devices[device.name] = device
+
+      return devices

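The tailscale SDK dependency is replaced by a plain REST call plus a small Device wrapper. A brief sketch of how the new helper might be exercised (the API key and tailnet name are placeholders):

import asyncio
from exo.networking.tailscale.tailscale_helpers import get_tailscale_devices


async def main():
  devices = await get_tailscale_devices("tskey-api-xxxxx", "example.com")  # placeholder credentials
  for name, device in devices.items():
    # Device.last_seen is parsed into a timezone-aware UTC datetime, or None if the API omitted it.
    print(name, device.device_id, device.addresses, device.last_seen)


asyncio.run(main())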
+ 2 - 0
exo/networking/tailscale/test_tailscale_discovery.py

@@ -5,6 +5,7 @@ from unittest import mock
 from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.peer_handle import PeerHandle

+
 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
@@ -37,5 +38,6 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     # Check if discovered peers are instances of GRPCPeerHandle
     print(peers)

+
 if __name__ == '__main__':
   unittest.main()

+ 1 - 0
exo/networking/udp/test_udp_discovery.py

@@ -6,6 +6,7 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.networking.grpc.grpc_server import GRPCServer
 from exo.orchestration.node import Node

+
 class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.peer1 = mock.AsyncMock()

+ 54 - 29
exo/networking/udp/udp_discovery.py

@@ -9,6 +9,7 @@ from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses

+
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
@@ -53,7 +54,7 @@ class UDPDiscovery(Discovery):
     self.broadcast_interval = broadcast_interval
     self.discovery_timeout = discovery_timeout
     self.device_capabilities = device_capabilities
-    self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
+    self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
     self.broadcast_task = None
     self.listen_task = None
     self.cleanup_task = None
@@ -76,31 +77,27 @@ class UDPDiscovery(Discovery):
       while len(self.known_peers) < wait_for_peers:
         if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
         await asyncio.sleep(0.1)
-    return [peer_handle for peer_handle, _, _ in self.known_peers.values()]
+    return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

   async def task_broadcast_presence(self):
-    message = json.dumps({
-      "type": "discovery",
-      "node_id": self.node_id,
-      "grpc_port": self.node_port,
-      "device_capabilities": self.device_capabilities.to_dict(),
-    })
-
-    if DEBUG_DISCOVERY >= 2:
-      print("Starting task_broadcast_presence...")
-      print(f"\nBroadcast message: {message}")
+    if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
 
     while True:
       # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
       # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
       for addr in get_all_ip_addresses():
+        message = json.dumps({
+          "type": "discovery",
+          "node_id": self.node_id,
+          "grpc_port": self.node_port,
+          "device_capabilities": self.device_capabilities.to_dict(),
+          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioritising interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+        })
+        if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
+
         transport = None
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-            lambda: BroadcastProtocol(message, self.broadcast_port),
-            local_addr=(addr, 0),
-            family=socket.AF_INET
-          )
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
           if DEBUG_DISCOVERY >= 3:
             print(f"Broadcasting presence at ({addr})")
         except Exception as e:
@@ -138,25 +135,31 @@ class UDPDiscovery(Discovery):
       peer_id = message["node_id"]
       peer_host = addr[0]
       peer_port = message["grpc_port"]
+      peer_prio = message["priority"]
       device_capabilities = DeviceCapabilities(**message["device_capabilities"])

       if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":
+        if peer_id in self.known_peers:
+          existing_peer_prio = self.known_peers[peer_id][3]
+          if existing_peer_prio >= peer_prio:
+            if DEBUG >= 1:
+              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
+            return
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         if not await new_peer_handle.health_check():
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.")
           return
         if DEBUG >= 1: print(f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}")
-        self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time())
+        self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio)
       else:
         if not await self.known_peers[peer_id][0].health_check():
           if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.")
           if peer_id in self.known_peers: del self.known_peers[peer_id]
           return
-        self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time())
+        if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)

   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
-                                                            local_addr=("0.0.0.0", self.listen_port))
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2: print("Started listen task")

   async def task_cleanup_peers(self):
       try:
       try:
         current_time = time.time()
         peers_to_remove = []
-          if (not await peer_handle.is_connected() and current_time - connected_at > self.discovery_timeout) or \
-             (current_time - last_seen > self.discovery_timeout) or \
-             (not await peer_handle.health_check()):
-            peers_to_remove.append(peer_id)
 
 
+        peer_ids = list(self.known_peers.keys())
+        results = await asyncio.gather(*[self.check_peer(peer_id, current_time) for peer_id in peer_ids], return_exceptions=True)
+
+        for peer_id, should_remove in zip(peer_ids, results):
+          if should_remove: peers_to_remove.append(peer_id)
+
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
+              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
+            }
+          )
 
 
         for peer_id in peers_to_remove:
-          if peer_id in self.known_peers: del self.known_peers[peer_id]
-          if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
+          if peer_id in self.known_peers:
+            del self.known_peers[peer_id]
+            if DEBUG_DISCOVERY >= 2: print(f"Removed peer {peer_id} due to inactivity or failed health check.")
       except Exception as e:
       except Exception as e:
         print(f"Error in cleanup peers: {e}")
         print(traceback.format_exc())
       finally:
         await asyncio.sleep(self.broadcast_interval)
+  async def check_peer(self, peer_id: str, current_time: float) -> bool:
+    peer_handle, connected_at, last_seen, prio = self.known_peers.get(peer_id, (None, None, None, None))
+    if peer_handle is None: return False
+
+    try:
+      is_connected = await peer_handle.is_connected()
+      health_ok = await peer_handle.health_check()
+    except Exception as e:
+      if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
+      return True
+
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
+    return should_remove

+ 2 - 2
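Each broadcast datagram now carries a per-interface priority alongside the existing fields, and on_listen_message keeps whichever entry for a given node_id has the higher priority when the same node shows up from two addresses. A sketch of one payload (the values are illustrative):

import json

message = json.dumps({
  "type": "discovery",
  "node_id": "my-node",  # placeholder id
  "grpc_port": 50051,
  "device_capabilities": {"model": "Unknown Model", "chip": "Unknown Chip", "memory": 0, "flops": {"fp32": 0, "fp16": 0, "int8": 0}},
  "priority": 1,  # constant for now; a strictly higher value would win over an existing entry for the same node_id
})
print(message)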
exo/orchestration/node.py

@@ -16,11 +16,11 @@ class Node(ABC):
     pass
     pass

   @abstractmethod
-  async def process_prompt(self, shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass

   @abstractmethod
-  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     pass

   @abstractmethod
+ 139 - 110
exo/orchestration/standard_node.py

@@ -4,7 +4,7 @@ import asyncio
 import uuid
 import uuid
 import time
 import traceback
-from typing import List, Dict, Optional, Tuple, Union
+from typing import List, Dict, Optional, Tuple, Union, Set
 from exo.networking import Discovery, PeerHandle, Server
 from exo.inference.inference_engine import InferenceEngine, Shard
 from .node import Node
@@ -15,7 +15,8 @@ from exo import DEBUG
 from exo.helpers import AsyncCallbackSystem
 from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
-
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader

 class StandardNode(Node):
   def __init__(
     partitioning_strategy: PartitioningStrategy = None,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
     topology_viz: Optional[TopologyViz] = None,
+    shard_downloader: Optional[HFShardDownloader] = None,
   ):
     self.id = _id
     self.inference_engine = inference_engine
@@ -37,12 +39,16 @@ class StandardNode(Node):
     self.topology: Topology = Topology()
     self.device_capabilities = device_capabilities()
     self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+    self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+    self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+    self.topology_inference_engines_pool: List[List[str]] = []
+    self.shard_downloader = shard_downloader

   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
   def on_node_status(self, request_id, opaque_status):
   def on_node_status(self, request_id, opaque_status):
     try:
       status_data = json.loads(opaque_status)
+      if status_data.get("type", "") == "supported_inference_engines":
+        node_id = status_data.get("node_id")
+        engines = status_data.get("engines", [])
+        self.topology_inference_engines_pool.append(engines)
       if status_data.get("type", "") == "node_status":
         if status_data.get("status", "").startswith("start_"):
           self.current_topology.active_node_id = status_data.get("node_id")
@@ -76,7 +86,56 @@ class StandardNode(Node):
       if DEBUG >= 1: print(f"Error updating visualization: {e}")
       if DEBUG >= 1: traceback.print_exc()
 
+  def get_supported_inference_engines(self):
+    supported_engine_names = []
+    if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
+    else:
+      supported_engine_names.append('tinygrad')
+    return supported_engine_names
+
+  async def broadcast_supported_engines(self, supported_engines_names: List[str]):
+    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
+    await self.broadcast_opaque_status("", status_message)
+
+  def get_topology_inference_engines(self) -> List[List[str]]:
+    return self.topology_inference_engines_pool
+  
+  async def process_inference_result(
+    self,
+    shard,
+    result: np.ndarray,
+    request_id: Optional[str] = None,
+  ):
+    if request_id not in self.buffered_token_output:
+      self.buffered_token_output[request_id] = ([], False)
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+    else:
+      forward = result
+
+    if is_finished:
+      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+    else:
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+
+    return np.array(self.buffered_token_output[request_id][0])
+
+  async def process_prompt(
+    self,
+    base_shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+  ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -88,14 +147,12 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
-          "inference_state": inference_state,
           "request_id": request_id,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -108,8 +165,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "image_str": image_str,
-          "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
           "result_size": resp.size if resp is not None else 0,
@@ -118,42 +173,26 @@
     )
     return resp

-  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
-    if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, image_str=image_str, inference_state=inference_state)
-      return
-
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
-    is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:
-      self.buffered_token_output[request_id][0].append(result.item())
-      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-    if not is_finished:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, image_str=image_str, inference_state=inference_state))
-
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if not shard.is_first_layer():
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      return None
+    else:
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      ret = await self.process_inference_result(shard, result, request_id) 
+      return result
 
   async def process_tensor(
     self,
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -168,12 +207,11 @@ class StandardNode(Node):
           "tensor_size": tensor.size,
           "tensor_shape": tensor.shape,
           "request_id": request_id,
-          "inference_state": inference_state,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    resp = await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -198,84 +236,77 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)

+    if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
-      if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-      result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if is_finished:
-        self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-      if result.size == 1:  # we got a new token out
-        self.buffered_token_output[request_id][0].append(result.item())
-        self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-
-      if not is_finished:
-        asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
-
-      return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+      ret = await self.process_inference_result(shard, result, request_id) 
+      return ret
     except Exception as e:
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       traceback.print_exc()
       return None
       return None
 
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     self,
     base_shard: Shard,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
     request_id: str,
-    image_str: Optional[str] = None,
-    inference_state: Optional[str] = None,
+    target_index: int,
   ) -> None:
   ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+  
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+
+  def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-    shard = self.get_current_shard(base_shard)
-
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     partitions = self.partitioning_strategy.partition(self.topology)
-    shards = map_partitions_to_shards(self.partitioning_strategy.partition(self.topology), base_shard.n_layers, base_shard.model_id)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if DEBUG >= 1: print(f"Current partition index: {current_partition_index}")
-    if current_partition_index is not None:
-      next_partition_index = (current_partition_index+1) % len(partitions)
-      next_partition: Partition = partitions[next_partition_index]
-      next_shard = shards[next_partition_index]
-      if DEBUG >= 2: print(f"Computed next from: {shard}, {self.topology}. Next partition: {next_partition}")
-
-      if next_partition.node_id == self.id:
-        if isinstance(tensor_or_prompt, np.ndarray):
-          await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
-        else:
-          await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
-        return
-
-      target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {next_partition} not found")
-
-      if DEBUG >= 1: print(f"Sending tensor_or_prompt to {target_peer.id()}: {tensor_or_prompt}")
-
-      if isinstance(tensor_or_prompt, np.ndarray):
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
-      else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
+    if current_partition_index is None:
+      raise ValueError(f"No current partition found for node: {self.id}")
+    return (current_partition_index + offset) % len(partitions)
 
 
-  def get_current_shard(self, base_shard: Shard) -> Shard:
+  def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+    if index is None:
+      index = self.get_partition_index()
     partitions = self.partitioning_strategy.partition(self.topology)
     partitions = self.partitioning_strategy.partition(self.topology)
     shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
     shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
-    current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
-    if current_partition_index is None:
-      raise ValueError(f"No current partition found for node: {self.id}")
-    return shards[current_partition_index]
+    return shards[index]
 
 
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
   async def update_peers(self, wait_for_peers: int = 0) -> bool:
     next_peers = await self.discovery.discover_peers(wait_for_peers)
     next_peers = await self.discovery.discover_peers(wait_for_peers)
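
Editor's note: the hunks above replace the old forward_to_next_shard (which took either a tensor or a prompt and always targeted "current partition + 1") with forward_prompt / forward_tensor plus the shared get_partition_index / get_current_shard helpers, and process_prompt now only runs infer_prompt on the node that owns the first layer, otherwise handing the prompt to partition 0. A minimal sketch of how those pieces compose, restated as free functions; `node` stands for a StandardNode and the methods used are the ones shown above, so treat this as an illustration, not exo's actual API surface:

def partition_index_for(node_id, partitions, offset=0):
  # Same modular arithmetic as get_partition_index: find our slot in the
  # ordered partition list, then step `offset` places around the ring.
  current = next(i for i, p in enumerate(partitions) if p.node_id == node_id)
  return (current + offset) % len(partitions)

async def route_prompt(node, base_shard, prompt, request_id):
  shard = node.get_current_shard(base_shard)
  if not shard.is_first_layer():
    # This node does not own layer 0: hand the prompt to partition 0 and stop.
    await node.forward_prompt(base_shard, prompt, request_id, 0)
    return None
  # Head of the ring: run the prefill locally; process_inference_result
  # (called above, not shown in full here) handles what follows.
  result = await node.inference_engine.infer_prompt(request_id, shard, prompt)
  return await node.process_inference_result(shard, result, request_id)

forward_tensor mirrors forward_prompt for activations, and get_partition_index(offset=1) gives the "next node in the ring" target for that path.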
@@ -283,20 +314,16 @@ class StandardNode(Node):
     next_peer_ids = {peer.id() for peer in next_peers}
     next_peer_ids = {peer.id() for peer in next_peers}
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
-    peers_updated = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
-    peers_unchanged = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
+    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
 
 
     def _pretty(peers: List[PeerHandle]) -> List[str]:
     def _pretty(peers: List[PeerHandle]) -> List[str]:
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
-    if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    if DEBUG >= 2:
+      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
 
 
     async def disconnect_with_timeout(peer, timeout=5):
     async def disconnect_with_timeout(peer, timeout=5):
       try:
       try:
@@ -316,14 +343,8 @@ class StandardNode(Node):
         traceback.print_exc()
         traceback.print_exc()
         return False
         return False
 
 
-    disconnect_results = await asyncio.gather(
-      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
-      return_exceptions=True
-    )
-    connect_results = await asyncio.gather(
-      *(connect_with_timeout(peer) for peer in peers_to_connect),
-      return_exceptions=True
-    )
+    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
 
 
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
@@ -338,6 +359,12 @@ class StandardNode(Node):
     self.peers = next_peers
     self.peers = next_peers
     return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
     return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
 
 
+  async def select_best_inference_engine(self):
+    supported_engines = self.get_supported_inference_engines()
+    await self.broadcast_supported_engines(supported_engines)
+    if len(self.get_topology_inference_engines()):
+      self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
+
   async def periodic_topology_collection(self, interval: int):
   async def periodic_topology_collection(self, interval: int):
     while True:
     while True:
       await asyncio.sleep(interval)
       await asyncio.sleep(interval)
@@ -346,6 +373,7 @@ class StandardNode(Node):
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if DEBUG >= 2: print(f"{did_peers_change=}")
         if did_peers_change:
         if did_peers_change:
           await self.collect_topology()
           await self.collect_topology()
+          await self.select_best_inference_engine()
       except Exception as e:
       except Exception as e:
         print(f"Error collecting topology: {e}")
         print(f"Error collecting topology: {e}")
         traceback.print_exc()
         traceback.print_exc()
@@ -382,6 +410,7 @@ class StandardNode(Node):
         self.topology.merge(other_topology)
         self.topology.merge(other_topology)
       except Exception as e:
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
         print(f"Error collecting topology from {peer.id()}: {e}")
+        traceback.print_exc()
 
 
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
     self.topology = next_topology
@@ -400,7 +429,7 @@ class StandardNode(Node):
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
   def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
     self.on_token.trigger_all(request_id, tokens, is_finished)
     self.on_token.trigger_all(request_id, tokens, is_finished)
-
+  
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
   async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
     async def send_result_to_peer(peer):
     async def send_result_to_peer(peer):
       try:
       try:

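Editor's note: select_best_inference_engine above is the new hook run after each topology change: broadcast what this node can run and, once peers have reported their engines, rebuild the local engine. A compact restatement outside the class; the three helpers it calls are only referenced, not defined, in this hunk (they live elsewhere in the codebase), so the wiring below is illustrative:

async def negotiate_engine(node, get_inference_engine):
  supported = node.get_supported_inference_engines()   # e.g. ["mlx", "tinygrad"] (names assumed)
  await node.broadcast_supported_engines(supported)    # tell peers what we can run
  if node.get_topology_inference_engines():            # peers have reported theirs
    # The commit keeps the policy simple: take this node's first supported
    # engine rather than computing an intersection across the topology.
    node.inference_engine = get_inference_engine(supported[0], node.shard_downloader)

As the periodic_topology_collection hunk shows, this runs right after collect_topology whenever the peer set changed.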
+ 150 - 64
exo/tinychat/index.css

@@ -1,31 +1,11 @@
 /* define colors */
 /* define colors */
 :root {
 :root {
-  --primary-color: #a52e4d;
-  --primary-color-transparent: #a52e4d66;
-  --secondary-color: #228039;
-  --secondary-color-transparent: #22803966;
-
+  --primary-color: #fff;
+  --secondary-color: #2a2a2a;
+  --secondary-color-transparent: #ffffff66;
+  --primary-bg-color: #1a1a1a;
+  --foreground-color: #f0f0f0;
   --red-color: #a52e4d;
   --red-color: #a52e4d;
-  --green-color: #228039;
-  --silver-color: #88808e;
-}
-@media(prefers-color-scheme: light) {
-  :root {
-    --primary-bg-color: #f0f0f0;
-    --secondary-bg-color: #eeeeee;
-    --tertiary-bg-color: #dddddd;
-    --foreground-color: #111111;
-    --accent-color: #000000;
-  }
-}
-@media(prefers-color-scheme: dark) {
-  :root {
-    --primary-bg-color: #111111;
-    --secondary-bg-color: #131313;
-    --tertiary-bg-color: #232323;
-    --foreground-color: #f0f0f0;
-    --accent-color: #aaaaaa;
-  }
 }
 }
 
 
 main {
 main {
@@ -81,7 +61,11 @@ main {
   top: 0;
   top: 0;
   position: absolute;
   position: absolute;
 
 
-  background: linear-gradient(180deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    180deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 }
 .histories-end {
 .histories-end {
   height: 3rem;
   height: 3rem;
@@ -91,7 +75,11 @@ main {
   bottom: 0;
   bottom: 0;
   position: absolute;
   position: absolute;
 
 
-  background: linear-gradient(0deg, var(--primary-bg-color) 0%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 0%,
+    transparent 100%
+  );
 }
 }
 
 
 .history {
 .history {
@@ -99,7 +87,7 @@ main {
   width: 100%;
   width: 100%;
   max-width: 40rem;
   max-width: 40rem;
 
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   border-radius: 10px;
   border-radius: 10px;
   border-left: 2px solid var(--primary-color);
   border-left: 2px solid var(--primary-color);
 
 
@@ -109,7 +97,7 @@ main {
   opacity: var(--opacity, 1);
   opacity: var(--opacity, 1);
 }
 }
 .history:hover {
 .history:hover {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
 }
 }
 
 
 .history-delete-button {
 .history-delete-button {
@@ -120,14 +108,14 @@ main {
   margin: 0;
   margin: 0;
   outline: none;
   outline: none;
   border: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   border-radius: 0 0 0 10px;
   cursor: pointer;
   cursor: pointer;
   transition: 0.2s;
   transition: 0.2s;
 }
 }
 .history-delete-button:hover {
 .history-delete-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
   padding: 0.75rem;
 }
 }
 
 
@@ -135,6 +123,7 @@ main {
   overflow-y: auto;
   overflow-y: auto;
   height: 100%;
   height: 100%;
   width: 100%;
   width: 100%;
+  max-width: 1200px;
 
 
   display: flex;
   display: flex;
   flex-direction: column;
   flex-direction: column;
@@ -145,44 +134,113 @@ main {
 }
 }
 
 
 .message {
 .message {
-  width: 96%;
-  max-width: 80rem;
-
-  display: grid;
-
-  background-color: var(--secondary-bg-color);
+  max-width: 75%;
   padding: 0.5rem 1rem;
   padding: 0.5rem 1rem;
-  border-radius: 10px;
+  border-radius: 20px;
 }
 }
 .message-role-assistant {
 .message-role-assistant {
-  border-bottom: 2px solid var(--primary-color);
-  border-left: 2px solid var(--primary-color);
-  box-shadow: -10px 10px 20px 2px var(--primary-color-transparent);
+  background-color: var(--secondary-color);
+  margin-right: auto;
+  color: #fff;
 }
 }
 .message-role-user {
 .message-role-user {
-  border-bottom: 2px solid var(--secondary-color);
-  border-right: 2px solid var(--secondary-color);
-  box-shadow: 10px 10px 20px 2px var(--secondary-color-transparent);
+  margin-left: auto;
+  background-color: var(--primary-color);
+  color: #000;
 }
 }
-.download-progress{
-  margin-bottom: 20em;
+.download-progress {
+  margin-bottom: 12em;
+  overflow-y: auto;
+  min-height: 350px;
+  padding: 2rem;
 }
 }
 .message > pre {
 .message > pre {
   white-space: pre-wrap;
   white-space: pre-wrap;
 }
 }
 
 
+.progress-bar-container {
+  width: 100%;
+  background-color: #e0e0e0;
+  border-radius: 4px;
+  margin: 10px 0;
+}
+.progress-bar {
+  height: 20px;
+  border-radius: 4px;
+  transition: width 0.5s ease-in-out;
+}
+.progress-bar.complete {
+  background-color: #4CAF50;
+}
+.progress-bar.in-progress {
+  background-color: #2196F3;
+}
+
 .toast {
 .toast {
-    width: 100%; /* Take up the full width of the page */
-    background-color: #fc2a2a; /* Dark background color */
-    color: #fff; /* White text color */
-    text-align: center; /* Centered text */
-    border-radius: 2px; /* Rounded borders */
-    padding: 16px; /* Padding */
-    position: fixed; /* Sit on top of the screen content */
-    z-index: 9999; /* Add a z-index if needed */
-    top: 0; /* Position at the top of the page */
-    left: 0; /* Extend from the left edge */
-    right: 0; /* Extend to the right edge */
+    width: 100%;
+    background-color: #fc2a2a;
+    color: #fff;
+    text-align: left;
+    border-radius: 2px;
+    padding: 16px;
+    position: fixed;
+    z-index: 9999;
+    top: 0;
+    left: 0;
+    right: 0;
+    display: flex;
+    flex-direction: column;
+    white-space: pre-wrap;
+    font-family: monospace;
+}
+
+.toast-header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    width: 100%;
+}
+
+.toast-error-message {
+    flex-grow: 1;
+}
+
+.toast-header-buttons {
+    display: flex;
+    align-items: center;
+    gap: 16px;
+    margin-left: 24px;
+}
+
+.toast-expand-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1em;
+}
+
+.toast-close-button {
+    background: none;
+    border: none;
+    color: white;
+    padding: 4px;
+    cursor: pointer;
+    font-size: 1.2em;
+    line-height: 1;
+}
+
+.toast-expand-button:hover,
+.toast-close-button:hover {
+    opacity: 0.8;
+}
+
+.toast-content {
+    margin-top: 10px;
+    padding: 10px;
+    background-color: rgba(0, 0, 0, 0.2);
+    border-radius: 4px;
 }
 }
 
 
 .hljs {
 .hljs {
@@ -201,14 +259,14 @@ main {
   margin: 0;
   margin: 0;
   outline: none;
   outline: none;
   border: none;
   border: none;
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   color: var(--foreground-color);
   border-radius: 0 0 0 10px;
   border-radius: 0 0 0 10px;
   cursor: pointer;
   cursor: pointer;
   transition: 0.2s;
   transition: 0.2s;
 }
 }
 .clipboard-button:hover {
 .clipboard-button:hover {
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   padding: 0.75rem;
   padding: 0.75rem;
 }
 }
 
 
@@ -217,9 +275,14 @@ main {
   bottom: 0;
   bottom: 0;
 
 
   /* linear gradient from background-color to transparent on the top */
   /* linear gradient from background-color to transparent on the top */
-  background: linear-gradient(0deg, var(--primary-bg-color) 55%, transparent 100%);
+  background: linear-gradient(
+    0deg,
+    var(--primary-bg-color) 55%,
+    transparent 100%
+  );
 
 
   width: 100%;
   width: 100%;
+  max-width: 1200px;
   display: flex;
   display: flex;
   flex-direction: column;
   flex-direction: column;
   justify-content: center;
   justify-content: center;
@@ -266,7 +329,7 @@ main {
   min-height: 3rem;
   min-height: 3rem;
   max-height: 8rem;
   max-height: 8rem;
 
 
-  background-color: var(--tertiary-bg-color);
+  background-color: var(--secondary-color);
   color: var(--foreground-color);
   color: var(--foreground-color);
   border-radius: 10px;
   border-radius: 10px;
   border: none;
   border: none;
@@ -278,8 +341,8 @@ main {
   height: 3rem;
   height: 3rem;
   width: 4rem;
   width: 4rem;
 
 
-  background-color: var(--secondary-color);
-  color: var(--foreground-color);
+  background-color: var(--primary-color);
+  color: var(--secondary-color);
   border-radius: 10px;
   border-radius: 10px;
   padding: 0.5rem;
   padding: 0.5rem;
   cursor: pointer;
   cursor: pointer;
@@ -288,7 +351,7 @@ main {
   background-color: var(--secondary-color-transparent);
   background-color: var(--secondary-color-transparent);
 }
 }
 .input-button:disabled {
 .input-button:disabled {
-  background-color: var(--secondary-bg-color);
+  background-color: var(--secondary-color);
   cursor: not-allowed;
   cursor: not-allowed;
 }
 }
 
 
@@ -395,4 +458,27 @@ p {
   max-width: 100%;
   max-width: 100%;
   max-height: 100%;
   max-height: 100%;
   object-fit: contain;
   object-fit: contain;
+}
+
+.clear-history-button {
+  background-color: var(--red-color);
+  color: white;
+  padding: 10px 20px;
+  border-radius: 5px;
+  display: flex;
+  align-items: center;
+  gap: 8px;
+  transition: all 0.3s ease;
+  margin: 1rem auto;
+  border: none;
+  cursor: pointer;
+}
+
+.clear-history-button:hover {
+  opacity: 0.8;
+  transform: scale(1.05);
+}
+
+.clear-history-button i {
+  font-size: 14px;
 }
 }

+ 51 - 32
exo/tinychat/index.html

@@ -26,31 +26,27 @@
 <body>
 <body>
 <main x-data="state" x-init="console.log(endpoint)">
 <main x-data="state" x-init="console.log(endpoint)">
      <!-- Error Toast -->
      <!-- Error Toast -->
-    <div x-show="errorMessage" x-transition.opacity x-text="errorMessage" class="toast">
+    <div x-show="errorMessage" x-transition.opacity class="toast">
+        <div class="toast-header">
+            <span class="toast-error-message" x-text="errorMessage.basic"></span>
+            <div class="toast-header-buttons">
+                <button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }" 
+                        class="toast-expand-button" 
+                        x-show="errorMessage.stack">
+                    <span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
+                </button>
+                <button @click="errorMessage = null; errorExpanded = false;" class="toast-close-button">
+                    <i class="fas fa-times"></i>
+                </button>
+            </div>
+        </div>
+        <div class="toast-content" x-show="errorExpanded" x-transition>
+            <span x-text="errorMessage.stack"></span>
+        </div>
     </div>
     </div>
 <div class="model-selector">
 <div class="model-selector">
-<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel">
-<option selected="" value="llama-3.2-1b">Llama 3.2 1B</option>
-<option value="llama-3.2-3b">Llama 3.2 3B</option>
-<option value="llama-3.1-8b">Llama 3.1 8B</option>
-<option value="llama-3.1-70b">Llama 3.1 70B</option>
-<option value="llama-3.1-70b-bf16">Llama 3.1 70B (BF16)</option>
-<option value="llama-3.1-405b">Llama 3.1 405B</option>
-<option value="llama-3-8b">Llama 3 8B</option>
-<option value="llama-3-70b">Llama 3 70B</option>
-<option value="mistral-nemo">Mistral Nemo</option>
-<option value="mistral-large">Mistral Large</option>
-<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
-<option value="deepseek-coder-v2.5">Deepseek Coder V2.5</option>
-<option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
-<option value="qwen-2.5-coder-1.5b">Qwen 2.5 Coder 1.5B</option>
-<option value="qwen-2.5-coder-7b">Qwen 2.5 Coder 7B</option>
-<option value="qwen-2.5-7b">Qwen 2.5 7B</option>
-<option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
-<option value="qwen-2.5-14b">Qwen 2.5 14B</option>
-<option value="qwen-2.5-72b">Qwen 2.5 72B</option>
-<option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
-</select>
+  <select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+  </select>
 </div>
 </div>
 <div @popstate.window="
 <div @popstate.window="
       if (home === 2) {
       if (home === 2) {
@@ -66,6 +62,13 @@
       if (home === -1) setTimeout(() =&gt; home = 0, 100);
       if (home === -1) setTimeout(() =&gt; home = 0, 100);
     " x-show="home === 0" x-transition="">
     " x-show="home === 0" x-transition="">
 <h1 class="title megrim-regular">tinychat</h1>
 <h1 class="title megrim-regular">tinychat</h1>
+<template x-if="histories.length">
+  <button 
+    @click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();" 
+    class="clear-history-button">
+    <i class="fas fa-trash"></i> Clear All History
+  </button>
+</template>
 <div class="histories-container-container">
 <div class="histories-container-container">
 <template x-if="histories.length">
 <template x-if="histories.length">
 <div class="histories-start"></div>
 <div class="histories-start"></div>
@@ -151,16 +154,32 @@
 </div>
 </div>
 
 
 <!-- Download Progress Section -->
 <!-- Download Progress Section -->
-<template x-if="downloadProgress">
-<div class="download-progress message message-role-assistant">
-  <h2>Download Progress</h2>
-  <div class="download-progress-node">
-    <p><strong>Model:</strong> <span x-text="downloadProgress.repo_id + '@' + downloadProgress.repo_revision"></span></p>
-    <p><strong>Progress:</strong> <span x-text="`${downloadProgress.downloaded_bytes_display} / ${downloadProgress.total_bytes_display} (${downloadProgress.percentage}%)`"></span></p>
-    <p><strong>Speed:</strong> <span x-text="downloadProgress.overall_speed_display || 'N/A'"></span></p>
-    <p><strong>ETA:</strong> <span x-text="downloadProgress.overall_eta_display || 'N/A'"></span></p>
+<template x-if="downloadProgress && downloadProgress.length > 0">
+  <div class="download-progress message message-role-assistant">
+    <h2>Download Progress</h2>
+    <br>
+    <template x-for="(progress, index) in downloadProgress" :key="index">
+      <div class="download-progress-node">
+        <br>
+        <h3 x-text="`Download ${index + 1}`"></h3>
+        <p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
+        <p><strong>Status:</strong> <span x-text="progress.status"></span></p>
+        <div class="progress-bar-container">
+          <div class="progress-bar" 
+               :class="progress.isComplete ? 'complete' : 'in-progress'"
+               :style="`width: ${progress.percentage}%;`">
+          </div>
+        </div>
+        <template x-if="!progress.isComplete">
+          <div>
+            <p><strong>Progress:</strong> <span x-text="`${progress.downloaded_bytes_display} / ${progress.total_bytes_display} (${progress.percentage}%)`"></span></p>
+            <p><strong>Speed:</strong> <span x-text="progress.overall_speed_display || 'N/A'"></span></p>
+            <p><strong>ETA:</strong> <span x-text="progress.overall_eta_display || 'N/A'"></span></p>
+          </div>
+        </template>
+      </div>
+    </template>
   </div>
   </div>
-</div>
 </template>
 </template>
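
Editor's note: the reworked template renders a list of downloads rather than a single one, reading repo_id, repo_revision, status, percentage, isComplete and the *_display fields from each entry. For reference, one entry as the template expects it after the client-side mapping in fetchDownloadProgress (index.js below); every concrete value here is invented:

# Shape only - all values below are made up for illustration.
progress_entry = {
  "repo_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",  # hypothetical repo
  "repo_revision": "main",
  "status": "in_progress",        # index.js special-cases "complete" and "failed"
  "isComplete": False,            # added client-side in fetchDownloadProgress
  "percentage": "50.00",          # string produced by toFixed(2)
  "downloaded_bytes_display": "1.2 GB",
  "total_bytes_display": "2.4 GB",
  "overall_speed_display": "35 MB/s",
  "overall_eta_display": "34s",
}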
 
 
 
 

+ 149 - 74
exo/tinychat/index.js

@@ -4,8 +4,8 @@ document.addEventListener("alpine:init", () => {
     cstate: {
     cstate: {
       time: null,
       time: null,
       messages: [],
       messages: [],
-      selectedModel: 'llama-3.1-8b',
-    },
+      selectedModel: 'llama-3.2-1b',
+    },    
 
 
     // historical state
     // historical state
     histories: JSON.parse(localStorage.getItem("histories")) || [],
     histories: JSON.parse(localStorage.getItem("histories")) || [],
@@ -14,6 +14,8 @@ document.addEventListener("alpine:init", () => {
     generating: false,
     generating: false,
     endpoint: `${window.location.origin}/v1`,
     endpoint: `${window.location.origin}/v1`,
     errorMessage: null,
     errorMessage: null,
+    errorExpanded: false,
+    errorTimeout: null,
 
 
     // performance tracking
     // performance tracking
     time_till_first: 0,
     time_till_first: 0,
@@ -32,8 +34,10 @@ document.addEventListener("alpine:init", () => {
 
 
     init() {
     init() {
       // Clean up any pending messages
       // Clean up any pending messages
-      this.pendingMessage = null;
       localStorage.removeItem("pendingMessage");
       localStorage.removeItem("pendingMessage");
+
+      // Start polling for download progress
+      this.startDownloadProgressPolling();
     },
     },
 
 
     removeHistory(cstate) {
     removeHistory(cstate) {
@@ -45,6 +49,12 @@ document.addEventListener("alpine:init", () => {
         localStorage.setItem("histories", JSON.stringify(this.histories));
         localStorage.setItem("histories", JSON.stringify(this.histories));
       }
       }
     },
     },
+
+    clearAllHistory() {
+      this.histories = [];
+      localStorage.setItem("histories", JSON.stringify([]));
+    },
+
     // Utility functions
     // Utility functions
     formatBytes(bytes) {
     formatBytes(bytes) {
       if (bytes === 0) return '0 B';
       if (bytes === 0) return '0 B';
@@ -64,6 +74,56 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
       return `${s}s`;
     },
     },
 
 
+    async populateSelector() {
+      try {
+        const response = await fetch(`${window.location.origin}/modelpool`);
+        const responseText = await response.text(); // Get raw response text first
+        
+        if (!response.ok) {
+          throw new Error(`HTTP error! status: ${response.status}`);
+        }
+        
+        // Try to parse the response text
+        let responseJson;
+        try {
+          responseJson = JSON.parse(responseText);
+        } catch (parseError) {
+          console.error('Failed to parse JSON:', parseError);
+          throw new Error(`Invalid JSON response: ${responseText}`);
+        }
+
+        const sel = document.querySelector(".model-select");
+        if (!sel) {
+          throw new Error("Could not find model selector element");
+        }
+
+        // Clear the current options and add new ones
+        sel.innerHTML = '';
+          
+        const modelDict = responseJson["model pool"];
+        if (!modelDict) {
+          throw new Error("Response missing 'model pool' property");
+        }
+
+        Object.entries(modelDict).forEach(([key, value]) => {
+          const opt = document.createElement("option");
+          opt.value = key;
+          opt.textContent = value;
+          sel.appendChild(opt);
+        });
+
+        // Set initial value to the first model
+        const firstKey = Object.keys(modelDict)[0];
+        if (firstKey) {
+          sel.value = firstKey;
+          this.cstate.selectedModel = firstKey;
+        }
+      } catch (error) {
+        console.error("Error populating model selector:", error);
+        this.errorMessage = `Failed to load models: ${error.message}`;
+      }
+    },
+
     async handleImageUpload(event) {
     async handleImageUpload(event) {
       const file = event.target.files[0];
       const file = event.target.files[0];
       if (file) {
       if (file) {
@@ -105,38 +165,34 @@ document.addEventListener("alpine:init", () => {
         el.style.height = "auto";
         el.style.height = "auto";
         el.style.height = el.scrollHeight + "px";
         el.style.height = el.scrollHeight + "px";
 
 
-        // Proceed to handle the message
+        localStorage.setItem("pendingMessage", value);
         this.processMessage(value);
         this.processMessage(value);
-
-        // Start polling for download progress
-        this.startDownloadProgressPolling();
-
-        // Delay the check for downloadProgress by 8 seconds without blocking execution
-        setTimeout(async () => {
-          this.pendingMessageHandler(value);
-        }, 8000);
-
       } catch (error) {
       } catch (error) {
-        console.error('error', error)
-        this.errorMessage = error.message || 'Errore durante l\'invio del messaggio.';
-        setTimeout(() => {
-          this.errorMessage = null;
-        }, 5 * 1000)
-      }
-    },
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
 
 
-    async pendingMessageHandler(value) {
-      console.log("Pending message handler called");
-      // Check if download is in progress
-      if (this.downloadProgress && this.downloadProgress.status !== "complete") {
-        // Save the message in pendingMessage
-        this.pendingMessage = value;
-        localStorage.setItem("pendingMessage", value);
-        console.log("Pending message saved:", localStorage.getItem("pendingMessage"));
-        // Inform the user
-        this.cstate.messages.push({ role: "system", content: "Download is in progress. Your message will be processed once the download completes." });
-        this.generating = false; // Reset generating
-        return;
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
+        this.generating = false;
       }
       }
     },
     },
 
 
@@ -252,11 +308,30 @@ document.addEventListener("alpine:init", () => {
           console.error("Failed to save histories to localStorage:", error);
           console.error("Failed to save histories to localStorage:", error);
         }
         }
       } catch (error) {
       } catch (error) {
-        console.error('error', error)
-        this.errorMessage = error;
-        setTimeout(() => {
-          this.errorMessage = null;
-        }, 5 * 1000)
+        console.error('error', error);
+        const errorDetails = {
+            message: error.message || 'Unknown error',
+            stack: error.stack,
+            name: error.name || 'Error'
+        };
+        
+        this.errorMessage = {
+            basic: `${errorDetails.name}: ${errorDetails.message}`,
+            stack: errorDetails.stack
+        };
+
+        // Clear any existing timeout
+        if (this.errorTimeout) {
+            clearTimeout(this.errorTimeout);
+        }
+
+        // Only set the timeout if the error details aren't expanded
+        if (!this.errorExpanded) {
+            this.errorTimeout = setTimeout(() => {
+                this.errorMessage = null;
+                this.errorExpanded = false;
+            }, 30 * 1000);
+        }
       } finally {
       } finally {
         this.generating = false;
         this.generating = false;
       }
       }
@@ -325,54 +400,60 @@ document.addEventListener("alpine:init", () => {
 
 
     async fetchDownloadProgress() {
     async fetchDownloadProgress() {
       try {
       try {
-        console.log("fetching download progress");
-        await new Promise(resolve => setTimeout(resolve, 4000)); // Necessary delay
         const response = await fetch(`${this.endpoint}/download/progress`);
         const response = await fetch(`${this.endpoint}/download/progress`);
         if (response.ok) {
         if (response.ok) {
           const data = await response.json();
           const data = await response.json();
           const progressArray = Object.values(data);
           const progressArray = Object.values(data);
           if (progressArray.length > 0) {
           if (progressArray.length > 0) {
-            const progress = progressArray[0];
-            // Check if download is complete
-            if (progress.status === "complete" || progress.status === "failed") {
-              this.downloadProgress = null; // Hide the progress section
-              // Stop polling
-              this.stopDownloadProgressPolling();
-
+            this.downloadProgress = progressArray.map(progress => {
+              // Check if download is complete
               if (progress.status === "complete") {
               if (progress.status === "complete") {
-                // Download is complete
-                // Check for pendingMessage
-                const savedMessage = localStorage.getItem("pendingMessage");
-                if (savedMessage) {
-                  // Clear pendingMessage
-                  this.pendingMessage = null;
-                  localStorage.removeItem("pendingMessage");
-                  // Call processMessage() with savedMessage
+                return {
+                  ...progress,
+                  isComplete: true,
+                  percentage: 100
+                };
+              } else if (progress.status === "failed") {
+                return {
+                  ...progress,
+                  isComplete: false,
+                  errorMessage: "Download failed"
+                };
+              } else {
+                return {
+                  ...progress,
+                  isComplete: false,
+                  downloaded_bytes_display: this.formatBytes(progress.downloaded_bytes),
+                  total_bytes_display: this.formatBytes(progress.total_bytes),
+                  overall_speed_display: progress.overall_speed ? this.formatBytes(progress.overall_speed) + '/s' : '',
+                  overall_eta_display: progress.overall_eta ? this.formatDuration(progress.overall_eta) : '',
+                  percentage: ((progress.downloaded_bytes / progress.total_bytes) * 100).toFixed(2)
+                };
+              }
+            });
+            const allComplete = this.downloadProgress.every(progress => progress.isComplete);
+            if (allComplete) {
+              // Check for pendingMessage
+              const savedMessage = localStorage.getItem("pendingMessage");
+              if (savedMessage) {
+                // Clear pendingMessage
+                localStorage.removeItem("pendingMessage");
+                // Call processMessage() with savedMessage
+                if (this.lastErrorMessage) {
                   await this.processMessage(savedMessage);
                   await this.processMessage(savedMessage);
                 }
                 }
               }
               }
-            } else {
-              // Compute human-readable strings
-              progress.downloaded_bytes_display = this.formatBytes(progress.downloaded_bytes);
-              progress.total_bytes_display = this.formatBytes(progress.total_bytes);
-              progress.overall_speed_display = progress.overall_speed ? this.formatBytes(progress.overall_speed) + '/s' : '';
-              progress.overall_eta_display = progress.overall_eta ? this.formatDuration(progress.overall_eta) : '';
-              progress.percentage = ((progress.downloaded_bytes / progress.total_bytes) * 100).toFixed(2);
-
-              this.downloadProgress = progress;
+              this.lastErrorMessage = null;
+              this.downloadProgress = null;
             }
             }
           } else {
           } else {
             // No ongoing download
             // No ongoing download
             this.downloadProgress = null;
             this.downloadProgress = null;
-            // Stop polling
-            this.stopDownloadProgressPolling();
           }
           }
         }
         }
       } catch (error) {
       } catch (error) {
         console.error("Error fetching download progress:", error);
         console.error("Error fetching download progress:", error);
         this.downloadProgress = null;
         this.downloadProgress = null;
-        // Stop polling in case of error
-        this.stopDownloadProgressPolling();
       }
       }
     },
     },
 
 
@@ -386,13 +467,6 @@ document.addEventListener("alpine:init", () => {
         this.fetchDownloadProgress();
         this.fetchDownloadProgress();
       }, 1000); // Poll every second
       }, 1000); // Poll every second
     },
     },
-
-    stopDownloadProgressPolling() {
-      if (this.downloadProgressInterval) {
-        clearInterval(this.downloadProgressInterval);
-        this.downloadProgressInterval = null;
-      }
-    },
   }));
   }));
 });
 });
 
 
@@ -549,6 +623,7 @@ function createParser(onParse) {
     }
     }
   }
   }
 }
 }
+
 const BOM = [239, 187, 191];
 const BOM = [239, 187, 191];
 function hasBom(buffer) {
 function hasBom(buffer) {
   return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
   return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
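
Editor's note: populateSelector above assumes a GET /modelpool endpoint on the same origin returning JSON of the form {"model pool": {model_key: display_name}}. A hedged sketch of a handler satisfying that contract; the route name and JSON key come from the code above, while the aiohttp wiring, the port and the entries are illustrative rather than exo's actual chatgpt_api implementation:

from aiohttp import web

# Hypothetical registry: model key -> human-readable name (keys as used in tinychat).
MODEL_POOL = {
  "llama-3.2-1b": "Llama 3.2 1B",
  "llama-3.1-8b": "Llama 3.1 8B",
}

async def handle_model_pool(request: web.Request) -> web.Response:
  return web.json_response({"model pool": MODEL_POOL})

app = web.Application()
app.add_routes([web.get("/modelpool", handle_model_pool)])

if __name__ == "__main__":
  web.run_app(app, port=52415)  # port used elsewhere in this commit; adjust as needed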

+ 44 - 41
exo/tinychat/update_deps.py

@@ -4,49 +4,52 @@ from bs4 import BeautifulSoup
 from urllib.parse import urljoin, urlparse
 from urllib.parse import urljoin, urlparse
 import re
 import re
 
 
+
 def download_file(url, local_path):
 def download_file(url, local_path):
-    response = requests.get(url)
-    if response.status_code == 200:
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, 'wb') as f:
-            f.write(response.content)
-        print(f"Downloaded: {local_path}")
-    else:
-        print(response.status_code)
-        print(f"Failed to download: {url}")
+  response = requests.get(url)
+  if response.status_code == 200:
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+      f.write(response.content)
+    print(f"Downloaded: {local_path}")
+  else:
+    print(response.status_code)
+    print(f"Failed to download: {url}")
+
 
 
 def update_html(html_content, base_url):
 def update_html(html_content, base_url):
-    soup = BeautifulSoup(html_content, 'html.parser')
+  soup = BeautifulSoup(html_content, 'html.parser')
 
 
-    for tag in soup.find_all(['script', 'link']):
-        if tag.has_attr('src'):
-            url = tag['src']
-        elif tag.has_attr('href'):
-            url = tag['href']
-        else:
-            continue
+  for tag in soup.find_all(['script', 'link']):
+    if tag.has_attr('src'):
+      url = tag['src']
+    elif tag.has_attr('href'):
+      url = tag['href']
+    else:
+      continue
+
+    if url.startswith(('http://', 'https://')):
+      full_url = url
+    else:
+      full_url = urljoin(base_url, url)
 
 
-        if url.startswith(('http://', 'https://')):
-            full_url = url
-        else:
-            full_url = urljoin(base_url, url)
+    parsed_url = urlparse(full_url)
+    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
 
 
-        parsed_url = urlparse(full_url)
-        local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+    download_file(full_url, local_path)
 
 
-        download_file(full_url, local_path)
+    relative_path = os.path.relpath(local_path, '.')
+    if tag.name == 'script':
+      tag['src'] = "/" + relative_path
+    elif tag.name == 'link':
+      tag['href'] = "/" + relative_path
 
 
-        relative_path = os.path.relpath(local_path, '.')
-        if tag.name == 'script':
-            tag['src'] = "/" + relative_path
-        elif tag.name == 'link':
-            tag['href'] = "/" + relative_path
+  return str(soup)
 
 
-    return str(soup)
 
 
 # Read the HTML file
 # Read the HTML file
 with open('./index.html', 'r') as f:
 with open('./index.html', 'r') as f:
-    html_content = f.read()
+  html_content = f.read()
 
 
 # Update HTML and download files
 # Update HTML and download files
 # updated_html = update_html(html_content, 'https://example.com')
 # updated_html = update_html(html_content, 'https://example.com')
@@ -68,7 +71,7 @@ download_file(css_url, css_output_path)
 
 
 # Parse CSS file for font URLs
 # Parse CSS file for font URLs
 with open(css_output_path, 'r', encoding='utf-8') as f:
 with open(css_output_path, 'r', encoding='utf-8') as f:
-    css_content = f.read()
+  css_content = f.read()
 
 
 # Extract font URLs from the CSS content
 # Extract font URLs from the CSS content
 font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
 font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
@@ -77,14 +80,14 @@ print(f"Found {len(font_urls)} font URLs")
 
 
 # Download font files
 # Download font files
 for font_url in font_urls:
 for font_url in font_urls:
-    font_url = font_url.strip('"\'')
-    if font_url.startswith('../'):
-        font_url = font_url[3:]
+  font_url = font_url.strip('"\'')
+  if font_url.startswith('../'):
+    font_url = font_url[3:]
 
 
-    # Use base_url instead of urljoin to keep the version number
-    full_url = base_url + font_url
-    relative_path = font_url
-    output_path = os.path.join(output_dir, relative_path)
-    download_file(full_url, output_path)
+  # Use base_url instead of urljoin to keep the version number
+  full_url = base_url + font_url
+  relative_path = font_url
+  output_path = os.path.join(output_dir, relative_path)
+  download_file(full_url, output_path)
 
 
-print("Download complete!")
+print("Download complete!")
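
Editor's note: besides the 4-space to 2-space re-indent, the script still pulls font URLs out of the downloaded CSS with a single regex and rebuilds absolute URLs by string-prefixing base_url (rather than urljoin) so the versioned path segment survives. A quick standalone check of that regex on an invented @font-face rule:

import re

# Invented CSS snippet; the pattern is the one used in update_deps.py.
css_content = """
@font-face {
  font-family: 'Megrim';
  src: url(../webfonts/megrim-v16-latin-regular.woff2) format('woff2'),
       url(../webfonts/megrim-v16-latin-regular.ttf) format('truetype');
}
"""

font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
print(font_urls)
# ['../webfonts/megrim-v16-latin-regular.woff2',
#  '../webfonts/megrim-v16-latin-regular.ttf']
# The script then strips any quotes, drops a leading '../', and prepends base_url.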

+ 29 - 16
exo/topology/device_capabilities.py

@@ -1,13 +1,13 @@
+from typing import Any
+from pydantic import BaseModel
 from exo import DEBUG
 from exo import DEBUG
-from dataclasses import dataclass, asdict
 import subprocess
 import subprocess
 import psutil
 import psutil
 
 
 TFLOPS = 1.00
 TFLOPS = 1.00
 
 
 
 
-@dataclass
-class DeviceFlops:
+class DeviceFlops(BaseModel):
   # units of TFLOPS
   # units of TFLOPS
   fp32: float
   fp32: float
   fp16: float
   fp16: float
@@ -17,11 +17,10 @@ class DeviceFlops:
     return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
     return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS"
 
 
   def to_dict(self):
   def to_dict(self):
-    return asdict(self)
+    return self.model_dump()
 
 
 
 
-@dataclass
-class DeviceCapabilities:
+class DeviceCapabilities(BaseModel):
   model: str
   model: str
   chip: str
   chip: str
   memory: int
   memory: int
@@ -30,7 +29,7 @@ class DeviceCapabilities:
   def __str__(self):
   def __str__(self):
     return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
     return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. Flops: {self.flops}"
 
 
-  def __post_init__(self):
+  def model_post_init(self, __context: Any) -> None:
     if isinstance(self.flops, dict):
     if isinstance(self.flops, dict):
       self.flops = DeviceFlops(**self.flops)
       self.flops = DeviceFlops(**self.flops)
 
 
@@ -53,9 +52,11 @@ CHIP_FLOPS = {
   "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
   "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS),
   "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
   "Apple M2 Ultra": DeviceFlops(fp32=26.98*TFLOPS, fp16=53.96*TFLOPS, int8=107.92*TFLOPS),
   "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
   "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
-  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
   "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
   "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS),
-  "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS),
+  "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS),
+  "Apple M4": DeviceFlops(fp32=4.26*TFLOPS, fp16=8.52*TFLOPS, int8=17.04*TFLOPS),
+  "Apple M4 Pro": DeviceFlops(fp32=5.72*TFLOPS, fp16=11.44*TFLOPS, int8=22.88*TFLOPS),
+  "Apple M4 Max": DeviceFlops(fp32=18.03*TFLOPS, fp16=36.07*TFLOPS, int8=72.14*TFLOPS),
   ### A chips
   ### A chips
   "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
   "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS),
   "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
   "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS),
@@ -72,6 +73,7 @@ CHIP_FLOPS = {
   "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
   "NVIDIA GEFORCE RTX 4070 SUPER": DeviceFlops(fp32=30.0*TFLOPS, fp16=60.0*TFLOPS, int8=120.0*TFLOPS),
   "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
   "NVIDIA GEFORCE RTX 4070": DeviceFlops(fp32=29.0*TFLOPS, fp16=58.0*TFLOPS, int8=116.0*TFLOPS),
   "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
   "NVIDIA GEFORCE RTX 4060 TI 16GB": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
+  "NVIDIA GEFORCE RTX 4060 TI": DeviceFlops(fp32=22.0*TFLOPS, fp16=44.0*TFLOPS, int8=88.0*TFLOPS),
   # RTX 30 series
   # RTX 30 series
   "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
   "NVIDIA GEFORCE RTX 3050": DeviceFlops(fp32=9.11*TFLOPS, fp16=18.22*TFLOPS, int8=36.44*TFLOPS),
   "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
   "NVIDIA GEFORCE RTX 3060": DeviceFlops(fp32=13.0*TFLOPS, fp16=26.0*TFLOPS, int8=52.0*TFLOPS),
@@ -89,14 +91,24 @@ CHIP_FLOPS = {
   "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS),
   "NVIDIA GEFORCE RTX 2070": DeviceFlops(fp32=7.46*TFLOPS, fp16=14.93*TFLOPS, int8=29.86*TFLOPS),
   "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS),
   "NVIDIA GEFORCE RTX 2070 SUPER": DeviceFlops(fp32=9.06*TFLOPS, fp16=18.12*TFLOPS, int8=36.24*TFLOPS),
   "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS),
   "NVIDIA GEFORCE RTX 2080": DeviceFlops(fp32=10.07*TFLOPS, fp16=20.14*TFLOPS, int8=40.28*TFLOPS),
+  "NVIDIA GEFORCE RTX 2080 TI": DeviceFlops(fp32=13.45*TFLOPS, fp16=26.9*TFLOPS, int8=40.28*TFLOPS),
   "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS),
   "NVIDIA GEFORCE RTX 2080 SUPER": DeviceFlops(fp32=11.15*TFLOPS, fp16=22.30*TFLOPS, int8=44.60*TFLOPS),
   "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS),
   "NVIDIA TITAN RTX": DeviceFlops(fp32=16.31*TFLOPS, fp16=32.62*TFLOPS, int8=65.24*TFLOPS),
-  # QUATRO RTX Ampere series
-  "NVIDIA QUATRO RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
-  "NVIDIA QUATRO RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
-  "NVIDIA QUATRO RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
-  "NVIDIA QUATRO RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
-  "NVIDIA QUATRO RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
+  # GTX 10 series
+  "NVIDIA GEFORCE GTX 1050 TI": DeviceFlops(fp32=2.0*TFLOPS, fp16=4.0*TFLOPS, int8=8.0*TFLOPS),
+  "NVIDIA GEFORCE GTX 1070": DeviceFlops(fp32=6.463*TFLOPS, fp16=0.101*TFLOPS, int8=25.852*TFLOPS),
+  "NVIDIA GEFORCE GTX 1080": DeviceFlops(fp32=8.873*TFLOPS, fp16=0.138*TFLOPS, int8=35.492*TFLOPS),
+  "NVIDIA GEFORCE GTX 1080 TI": DeviceFlops(fp32=11.34*TFLOPS, fp16=0.177*TFLOPS, int8=45.36*TFLOPS),
+  # GTX 16 series
+  "NVIDIA GeForce GTX 1660 TI": DeviceFlops(fp32=4.8*TFLOPS, fp16=9.6*TFLOPS, int8=19.2*TFLOPS),
+  # QUADRO RTX Ampere series
+  "NVIDIA RTX A2000": DeviceFlops(fp32=7.99*TFLOPS, fp16=7.99*TFLOPS, int8=31.91*TFLOPS),
+  "NVIDIA RTX A4000": DeviceFlops(fp32=19.17*TFLOPS, fp16=19.17*TFLOPS, int8=76.68*TFLOPS),
+  "NVIDIA RTX A4500": DeviceFlops(fp32=23.65*TFLOPS, fp16=23.65*TFLOPS, int8=94.6*TFLOPS),
+  "NVIDIA RTX A5000": DeviceFlops(fp32=27.8*TFLOPS, fp16=27.8*TFLOPS, int8=111.2*TFLOPS),
+  "NVIDIA RTX A6000": DeviceFlops(fp32=38.71*TFLOPS, fp16=38.71*TFLOPS, int8=154.84*TFLOPS),
+  # NVIDIA Ada Lovelace Architecture-Based
+  "NVIDIA RTX 4000 ADA GENERATION": DeviceFlops(fp32=26.7*TFLOPS, fp16=26.7*TFLOPS, int8=258.0*TFLOPS),
   # Common Server GPUs
   # Common Server GPUs
   "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
   "NVIDIA A40 48GB PCIE": DeviceFlops(fp32=37.4*TFLOPS, fp16=149.7*TFLOPS, int8=299.3*TFLOPS),
   "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
   "NVIDIA A100 40GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
@@ -176,7 +188,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
 
 
     pynvml.nvmlInit()
     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
-    gpu_name = pynvml.nvmlDeviceGetName(handle).upper()
+    gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
+    gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
     gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
     gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
 
 
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
     if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
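
Editor's note: device_capabilities.py swaps the dataclasses for pydantic models (model_post_init standing in for __post_init__) and normalizes NVML names so entries like the new "NVIDIA GEFORCE RTX 4060 TI" are found even when the driver reports a trailing memory-size token. A small self-contained sketch of both behaviours, using the same field names as the diff:

from typing import Any
from pydantic import BaseModel

class DeviceFlops(BaseModel):
  fp32: float
  fp16: float
  int8: float

class DeviceCapabilities(BaseModel):
  model: str
  chip: str
  memory: int
  flops: DeviceFlops

  def model_post_init(self, __context: Any) -> None:
    # Mirrors the commit's hook: coerce a raw dict into DeviceFlops if one
    # slipped through (normal pydantic validation usually does this already).
    if isinstance(self.flops, dict):
      self.flops = DeviceFlops(**self.flops)

caps = DeviceCapabilities(model="Linux Box", chip="NVIDIA GEFORCE RTX 4060 TI",
                          memory=16384, flops={"fp32": 22.0, "fp16": 44.0, "int8": 88.0})

# NVML reports names like "NVIDIA GEFORCE RTX 4060 TI 16GB"; the CHIP_FLOPS key
# omits the memory size, so the commit strips a trailing "<n>GB" token:
gpu_raw_name = "NVIDIA GEFORCE RTX 4060 TI 16GB"
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
assert gpu_name == "NVIDIA GEFORCE RTX 4060 TI"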

+ 1 - 1
exo/viz/topology_viz.py

@@ -161,7 +161,7 @@ class TopologyViz:
 
 
     # Calculate total FLOPS and position on the bar
     # Calculate total FLOPS and position on the bar
     total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
     total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
-    bar_pos = (math.tanh(math.cbrt(total_flops)/2.5 - 2) + 1)
+    bar_pos = (math.tanh(total_flops**(1/3)/2.5 - 2) + 1)
 
 
     # Add GPU poor/rich bar
     # Add GPU poor/rich bar
     bar_width = 30
     bar_width = 30
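
Editor's note: the only change here swaps math.cbrt for a ** (1/3) power, likely so the viz keeps working on Pythons older than 3.11, where math.cbrt does not exist. For the non-negative TFLOPS sums fed into this formula the two agree; a quick check of the bar-position expression:

import math

# total_flops is a sum of non-negative fp16 TFLOPS values, so the cube root and
# the ** (1/3) power only diverge for negative inputs, which cannot occur here.
for total_flops in (0.0, 7.1, 53.96, 996.0):
  bar_pos = math.tanh(total_flops**(1/3) / 2.5 - 2) + 1
  assert 0.0 <= bar_pos <= 2.0
  if hasattr(math, "cbrt"):  # Python >= 3.11
    assert abs(math.cbrt(total_flops) - total_flops**(1/3)) < 1e-9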

+ 1 - 1
extra/start_openwebui.sh

@@ -1,3 +1,3 @@
-API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):8000}"
+API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 echo "Using API_ENDPOINT=${API_ENDPOINT}"
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main
 docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main

+ 1 - 1
format.py

@@ -21,7 +21,7 @@ def run_yapf(target):
 
 
 def main():
   if len(sys.argv) < 2:
-    print("Usage: python format.py <directory_or_file>")
+    print("Usage: python3 format.py <directory_or_file> e.g. python3 format.py ./exo")
     sys.exit(1)

   target = sys.argv[1]

+ 0 - 5
lint.sh

@@ -1,5 +0,0 @@
-#!/bin/bash
-
-pip3 install -e '.[linting]'
-python3 -m ruff check .
-python3 -m pylint .

+ 0 - 7
pyproject.toml

@@ -1,7 +0,0 @@
-[tool.pylint.format]
-indent-string = '  '
-max-line-length = 200
-
-[tool.autopep8]
-max_line_length = 200
-indent_size = 2

+ 0 - 43
ruff.toml

@@ -1,43 +0,0 @@
-indent-width = 2
-preview = true
-target-version = "py312"
-
-lint.select = [
-  "F",  # Pyflakes
-  "W6",
-  "E71",
-  "E72",
-  "E112",   # no-indented-block
-  "E113",   # unexpected-indentation
-  # "E124",
-  "E203",   # whitespace-before-punctuation
-  "E272",   # multiple-spaces-before-keyword
-  "E303",   # too-many-blank-lines
-  "E304",   # blank-line-after-decorator
-  "E501",   # line-too-long
-  # "E502",
-  "E702",   # multiple-statements-on-one-line-semicolon
-  "E703",   # useless-semicolon
-  "E731",   # lambda-assignment
-  "W191",   # tab-indentation
-  "W291",   # trailing-whitespace
-  "W293",   # blank-line-with-whitespace
-  "UP039",  # unnecessary-class-parentheses
-  "C416",   # unnecessary-comprehension
-  "RET506", # superfluous-else-raise
-  "RET507", # superfluous-else-continue
-  "A",      # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
-  "SIM105", # suppressible-exception
-  "FURB110",# if-exp-instead-of-or-operator
-]
-
-line-length = 200
-
-exclude = [
-  "docs/",
-  "examples/",
-  "extra/",
-  "exo/networking/grpc/node_service_pb2.py",
-  "exo/networking/grpc/node_service_pb2_grpc.py",
-  "exo/helpers.py",
-]

+ 58 - 0
scripts/build_exo.py

@@ -0,0 +1,58 @@
+import site
+import subprocess
+import sys
+import os 
+import pkgutil
+
+def run():
+    site_packages = site.getsitepackages()[0]
+    command = [
+        f"{sys.executable}", "-m", "nuitka", "exo/main.py",
+        "--company-name=exolabs",
+        "--product-name=exo",
+        "--output-dir=dist",
+        "--follow-imports",
+        "--standalone",
+        "--output-filename=exo",
+        "--python-flag=no_site",
+        "--onefile"
+    ]
+
+    if sys.platform == "darwin": 
+        command.extend([
+            "--macos-app-name=exo",
+            "--macos-app-mode=gui",
+            "--macos-app-version=0.0.1",
+            "--macos-signed-app-name=com.exolabs.exo",
+            "--include-distribution-meta=mlx",
+            "--include-module=mlx._reprlib_fix",
+            "--include-module=mlx._os_warning",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib",
+            f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib",
+            "--include-distribution-meta=pygments",
+            "--nofollow-import-to=tinygrad"
+        ])
+        inference_modules = [
+            name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models'])
+        ]
+        for module in inference_modules:
+            command.append(f"--include-module=exo.inference.mlx.models.{module}")
+    elif sys.platform == "win32":  
+        command.extend([
+            "--windows-icon-from-ico=docs/exo-logo-win.ico",
+            "--file-version=0.0.1",
+            "--product-version=0.0.1"
+        ])
+    elif sys.platform.startswith("linux"):  
+        command.extend([
+            "--include-distribution-metadata=pygments",
+            "--linux-icon=docs/exo-rounded.png"
+        ])
+    try:
+        subprocess.run(command, check=True)
+        print("Build completed!")
+    except subprocess.CalledProcessError as e:
+        print(f"An error occurred: {e}")
+
+if __name__ == "__main__":
+    run()
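
Note on the new script above: it drives Nuitka in standalone/onefile mode to produce a single exo binary under dist/, bundling the MLX metallib and model modules on macOS and adding icon/version metadata on Windows and Linux. A minimal way to invoke it from the repository root (a sketch; it simply shells out to the script):

    import subprocess, sys

    # Equivalent to running `python3 scripts/build_exo.py`; the dist/ location comes from --output-dir.
    subprocess.run([sys.executable, "scripts/build_exo.py"], check=True)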

+ 7 - 0
scripts/compile_grpc.sh

@@ -0,0 +1,7 @@
+#!/bin/bash
+source ./install.sh
+pushd exo/networking/grpc
+python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
+sed -i "s/import\ node_service_pb2/from . &/" node_service_pb2_grpc.py
+popd
+
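
Note on the new script above: it regenerates the gRPC stubs from node_service.proto and then rewrites the generated import so it is package-relative. The generation step can also be done from Python via grpc_tools (a sketch; the relative-import fix performed by the sed line is still needed afterwards):

    import os
    from grpc_tools import protoc

    os.chdir("exo/networking/grpc")  # mirrors the pushd in the script
    protoc.main(["grpc_tools.protoc", "-I.", "--python_out=.", "--grpc_python_out=.", "node_service.proto"])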

+ 19 - 20
setup.py

@@ -1,49 +1,48 @@
 import sys
+import platform
 
 
 from setuptools import find_packages, setup

 # Base requirements for all platforms
 install_requires = [
-  "aiohttp==3.10.2",
+  "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "netifaces==0.11.0",
   "numpy==2.0.0",
   "numpy==2.0.0",
+  "nuitka==2.5.1",
+  "nvidia-ml-py==12.560.30",
   "pillow==10.4.0",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
   "prometheus-client==0.20.0",
-  "protobuf==5.27.1",
+  "protobuf==5.28.1",
   "psutil==6.0.0",
   "psutil==6.0.0",
-  "pynvml==11.5.3",
+  "pydantic==2.9.2",
   "requests==2.32.3",
   "requests==2.32.3",
   "rich==13.7.1",
   "rich==13.7.1",
-  "safetensors==0.4.3",
-  "tailscale==0.6.1",
   "tenacity==9.0.0",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "uuid==1.30",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
 ]

-# Add macOS-specific packages if on Darwin (macOS)
-if sys.platform.startswith("darwin"):
-  install_requires.extend([
-    "mlx==0.18.0",
-    "mlx-lm==0.18.2",
-  ])
-
 extras_require = {
-  "linting": [
-    "pylint==3.2.6",
-    "ruff==0.5.5",
-    "mypy==1.11.0",
+  "formatting": [
     "yapf==0.40.2",
     "yapf==0.40.2",
   ],
   ],
+  "apple_silicon": [
+    "mlx==0.20.0",
+    "mlx-lm==0.19.3",
+  ],
 }

+# Check if running on macOS with Apple Silicon
+if sys.platform.startswith("darwin") and platform.machine() == "arm64":
+  install_requires.extend(extras_require["apple_silicon"])
+
 setup(
   name="exo",
   version="0.0.1",

+ 1 - 1
test/reconnect.sh

@@ -1,7 +1,7 @@
 #!/bin/bash

 echo "Starting node 1"
-DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
+DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
 PID1=$!
 echo "Started node 1 PID: $PID1"
 echo "Starting node 2"

+ 121 - 0
test/test_model_helpers.py

@@ -0,0 +1,121 @@
+import unittest
+from exo.models import get_supported_models, model_cards
+from exo.inference.inference_engine import inference_engine_classes
+from typing import NamedTuple
+
+class TestCase(NamedTuple):
+  name: str
+  engine_lists: list  # Will contain short names, will be mapped to class names
+  expected_models_contains: list
+  min_count: int | None
+  exact_count: int | None
+  max_count: int | None
+
+# Helper function to map short names to class names
+def expand_engine_lists(engine_lists):
+  def map_engine(engine):
+    return inference_engine_classes.get(engine, engine)  # Return original name if not found
+
+  return [[map_engine(engine) for engine in sublist]
+          for sublist in engine_lists]
+
+test_cases = [
+  TestCase(
+    name="single_mlx_engine",
+    engine_lists=[["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="single_tinygrad_engine",
+    engine_lists=[["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="multiple_engines_or",
+    engine_lists=[["mlx", "tinygrad"], ["mlx"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="multiple_engines_all",
+    engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]],
+    expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"],
+    min_count=10,
+    exact_count=None,
+    max_count=None
+  ),
+  TestCase(
+    name="distinct_engine_lists",
+    engine_lists=[["mlx"], ["tinygrad"]],
+    expected_models_contains=["llama-3.2-1b"],
+    min_count=5,
+    exact_count=None,
+    max_count=10
+  ),
+  TestCase(
+    name="no_engines",
+    engine_lists=[],
+    expected_models_contains=None,
+    min_count=None,
+    exact_count=len(model_cards),
+    max_count=None
+  ),
+  TestCase(
+    name="nonexistent_engine",
+    engine_lists=[["NonexistentEngine"]],
+    expected_models_contains=[],
+    min_count=None,
+    exact_count=0,
+    max_count=None
+  ),
+  TestCase(
+    name="dummy_engine",
+    engine_lists=[["dummy"]],
+    expected_models_contains=["dummy"],
+    min_count=None,
+    exact_count=1,
+    max_count=None
+  ),
+]
+
+class TestModelHelpers(unittest.TestCase):
+  def test_get_supported_models(self):
+    for case in test_cases:
+      with self.subTest(f"{case.name}_short_names"):
+        result = get_supported_models(case.engine_lists)
+        self._verify_results(case, result)
+
+      with self.subTest(f"{case.name}_class_names"):
+        class_name_lists = expand_engine_lists(case.engine_lists)
+        result = get_supported_models(class_name_lists)
+        self._verify_results(case, result)
+
+  def _verify_results(self, case, result):
+    if case.expected_models_contains:
+      for model in case.expected_models_contains:
+        self.assertIn(model, result)
+
+    if case.min_count:
+      self.assertGreater(len(result), case.min_count)
+
+    if case.exact_count is not None:
+      self.assertEqual(len(result), case.exact_count)
+
+    # Special case for distinct lists test
+    if case.name == "distinct_engine_lists":
+      self.assertLess(len(result), 10)
+      self.assertNotIn("mistral-nemo", result)
+
+    if case.max_count:
+      self.assertLess(len(result), case.max_count)
+
+if __name__ == '__main__':
+  unittest.main()
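
Note on the new test above: each TestCase is exercised twice, once with the short engine names and once after mapping them to the class names registered in inference_engine_classes. Assuming the repository root as the working directory, one way to run just this file (a sketch using the standard unittest runner):

    import unittest

    suite = unittest.defaultTestLoader.discover("test", pattern="test_model_helpers.py")
    unittest.TextTestRunner(verbosity=2).run(suite)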

+ 8 - 3
test/test_tokenizers.py

@@ -1,7 +1,7 @@
 import os
 import re
 from transformers import AutoTokenizer, AutoProcessor
-from exo.models import model_base_shards
+from exo.models import model_cards
 
 
 
 
 def test_tokenizer(name, tokenizer, verbose=False):
@@ -24,9 +24,14 @@ def test_tokenizer(name, tokenizer, verbose=False):
     strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
     assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
 
 
-ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*"]
+ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
 ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
-models = [shard.model_id for shards in model_base_shards.values() for shard in shards.values() if not ignore_pattern.match(shard.model_id)]
+models = []
+for model_id in model_cards:
+  for engine_type, repo_id in model_cards[model_id].get("repo", {}).items():
+    if not ignore_pattern.match(repo_id):
+      models.append(repo_id)
+models = list(set(models))
 
 
 verbose = os.environ.get("VERBOSE", "0").lower() == "1"
 for m in models:
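
Note on the hunk above: the model list is now derived from the repo mapping in model_cards instead of the removed model_base_shards, and deduplicated because several engines can point at the same repo. Each card maps inference-engine class names to Hugging Face repo ids; the shape the loop expects looks roughly like this (entries are illustrative, not taken from this diff):

    model_cards = {
      "llama-3.2-1b": {
        "repo": {
          "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
          "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
        },
      },
    }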

Some files were not shown because too many files changed in this diff