Merge remote-tracking branch 'upstream/main' into tidy/models

thenatlog 9 months ago
parent
commit
f5d5e44146

+ 0 - 472
.pylintrc

@@ -1,472 +0,0 @@
-[MASTER]
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code
-extension-pkg-whitelist=scipy,cereal.messaging.messaging_pyx,PyQt5,av
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=.*node_service_pb2.*
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint.
-jobs=4
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Specify a configuration file.
-#rcfile=
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
-confidence=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once).You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use"--disable=all --enable=classes
-# --disable=W"
-disable=C,R,W0613,W0511,W0212,W0201,W0106,W0603,W0621,W0703,W1201,W1203,E1136,W1514,E1101,W0221,W0105,E0401
-# E1101 for function binding
-# W0221 for Function class
-# W0105 for comment strings
-# E0401 for missing imports
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member,use-a-generator, no-else-return
-
-
-[REPORTS]
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details
-#msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio).You can also give a reporter class, eg
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Activate the evaluation score.
-score=yes
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=optparse.Values,sys.exit
-
-
-[LOGGING]
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format
-logging-modules=logging
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-
-[SIMILARITIES]
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=capnp.* cereal.* pygame.* zmq.* setproctitle.* smbus2.* usb1.* serial.* cv2.* ft4222.* carla.*
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=flask setproctitle usb1 flask.ext.socketio smbus2 usb1.*
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid to define new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expectedly
-# not used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging  or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='  '
-
-# Maximum number of characters on a single line.
-max-line-length=150
-
-# Maximum number of lines in a module
-max-module-lines=1000
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[BASIC]
-
-# Naming style matching correct argument names
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style
-#argument-rgx=
-
-# Naming style matching correct attribute names
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma
-bad-names=foo,
-          bar,
-          baz,
-          toto,
-          tutu,
-          tata
-
-# Naming style matching correct class attribute names
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style
-#class-attribute-rgx=
-
-# Naming style matching correct class names
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-style
-#class-rgx=
-
-# Naming style matching correct constant names
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma
-good-names=i,
-           j,
-           k,
-           ex,
-           Run,
-           _
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style
-#inlinevar-rgx=
-
-# Naming style matching correct method names
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style
-#method-rgx=
-
-# Naming style matching correct module names
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-property-classes=abc.abstractproperty
-
-# Naming style matching correct variable names
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style
-#variable-rgx=
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in a if statement
-max-bool-expr=5
-
-# Maximum number of branch for function / method body
-max-branches=12
-
-# Maximum number of locals for function / method body
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body
-max-returns=6
-
-# Maximum number of statements in function / method body
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=mcs
-
-
-[IMPORTS]
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules=regsub,
-                   TERMIOS,
-                   Bastion,
-                   rexec
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled)
-ext-import-graph=
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled)
-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled)
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-[STRING]
-
-# This flag controls whether the implicit-str-concat should generate a warning
-# on implicit string concatenation in sequences defined over several lines.
-check-str-concat-over-line-jumps=yes
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "Exception"
-overgeneral-exceptions=builtins.Exception

+ 14 - 0
README.md

@@ -222,6 +222,20 @@ For the **tinygrad** inference engine specifically, there is a separate DEBUG fl
 TINYGRAD_DEBUG=2 exo
 ```
 
+## Formatting
+
+We use [yapf](https://github.com/google/yapf) to format the code. To format the code, first install the formatting requirements:
+
+```sh
+pip3 install -e '.[formatting]'
+```
+
+Then run the formatting script:
+
+```sh
+python3 format.py ./exo
+```
+
 ## Known Issues
 
 - On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
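> Note on the formatting setup above: yapf can also be driven programmatically. Below is a minimal sketch of what a formatter script like `format.py` might do; the style knobs (2-space indents, 150-column limit) are inferred from this diff and the removed `.pylintrc`, not read from `format.py` itself, and the exact `FormatCode` return shape can vary across yapf versions.

```python
# Hedged sketch: format Python sources in place with yapf's public API.
import pathlib
from yapf.yapflib.yapf_api import FormatCode

# Inline style spec; values are assumptions inferred from this diff.
STYLE = '{based_on_style: pep8, indent_width: 2, column_limit: 150}'

for path in pathlib.Path('./exo').rglob('*.py'):
  src = path.read_text()
  formatted, changed = FormatCode(src, style_config=STYLE)
  if changed:
    path.write_text(formatted)
    print(f'reformatted {path}')
```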

+ 4 - 7
exo/api/chatgpt_api.py

@@ -178,7 +178,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
 
-    self.static_dir = Path(__file__).parent.parent / "tinychat"
+    self.static_dir = Path(__file__).parent.parent/"tinychat"
     self.app.router.add_get("/", self.handle_root)
     self.app.router.add_static("/", self.static_dir, name="static")
 
@@ -191,6 +191,7 @@ class ChatGPTAPI:
         return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
       except asyncio.TimeoutError:
         return web.json_response({"detail": "Request timed out"}, status=408)
+
     return middleware
 
   async def log_request(self, app, handler):
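> The timeout middleware touched above follows a standard aiohttp pattern: a factory closes over the timeout and returns a middleware that races each handler against it. A hedged standalone sketch of the same guard, decoupled from the class:

```python
import asyncio
from aiohttp import web

def timeout_middleware(timeout_s: float):
  @web.middleware
  async def middleware(request, handler):
    try:
      # Bound every request; slow handlers surface as 408s instead of hanging.
      return await asyncio.wait_for(handler(request), timeout=timeout_s)
    except asyncio.TimeoutError:
      return web.json_response({"detail": "Request timed out"}, status=408)
  return middleware
```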
@@ -204,7 +205,7 @@ class ChatGPTAPI:
     return web.FileResponse(self.static_dir/"index.html")
 
   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True } for model_name, _ in model_base_shards.items()])
+    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_base_shards.items()])
 
   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
@@ -222,7 +223,6 @@ class ChatGPTAPI:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
     return web.json_response(progress_data)
 
-
   async def handle_post_chat_completions(self, request):
     data = await request.json()
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
@@ -270,10 +270,7 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=} {image_str=}")
 
     try:
-      await asyncio.wait_for(
-        asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))),
-        timeout=self.response_timeout
-      )
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, image_str, request_id=request_id))), timeout=self.response_timeout)
 
       if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
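> The `asyncio.shield` inside `wait_for` in this hunk is doing real work: on timeout, `wait_for` cancels what it awaits, and the shield absorbs that cancellation so the underlying `process_prompt` task keeps running. A minimal standalone sketch of the pattern:

```python
import asyncio

async def slow_work() -> str:
  await asyncio.sleep(0.3)
  return "done"

async def main():
  task = asyncio.create_task(slow_work())
  try:
    await asyncio.wait_for(asyncio.shield(task), timeout=0.1)
  except asyncio.TimeoutError:
    # wait_for cancelled the shield, not the task itself.
    print("timed out; task still running:", not task.done())
  print(await task)  # the shielded task can still complete: "done"

asyncio.run(main())
```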
 

+ 4 - 2
exo/download/hf/hf_helpers.py

@@ -70,8 +70,10 @@ def _add_wildcard_to_directories(pattern: str) -> str:
     return pattern + "*"
   return pattern
 
+
 def get_hf_endpoint() -> str:
-    return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+  return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
+
 
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
@@ -394,7 +396,7 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
 
 
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
-  default_patterns = set(["*.json","*.py","tokenizer.model","*.tiktoken","*.txt"])
+  default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
   if weight_map:
     for tensor_name, filename in weight_map.items():
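> The new `get_hf_endpoint` makes the Hugging Face base URL overridable per environment, which is useful behind mirrors or proxies. A small usage sketch; the mirror URL is a placeholder, not a real endpoint:

```python
import os

# Override for this process only; "https://hf.example" is hypothetical.
os.environ['HF_ENDPOINT'] = 'https://hf.example'

endpoint = os.environ.get('HF_ENDPOINT', 'https://huggingface.co')
print(endpoint)  # -> https://hf.example
```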

+ 1 - 0
exo/download/shard_download.py

@@ -25,6 +25,7 @@ class ShardDownloader(ABC):
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     pass
 
+
 class NoopShardDownloader(ShardDownloader):
   async def ensure_shard(self, shard: Shard) -> Path:
     return Path("/tmp/noop_shard")

+ 1 - 1
exo/helpers.py

@@ -170,7 +170,7 @@ def is_valid_uuid(val):
 
 
 def get_or_create_node_id():
-  NODE_ID_FILE = Path(tempfile.gettempdir()) / ".exo_node_id"
+  NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
   try:
     if NODE_ID_FILE.is_file():
       with open(NODE_ID_FILE, "r") as f:

+ 1 - 0
exo/inference/dummy_inference_engine.py

@@ -5,6 +5,7 @@ import json
 from exo.inference.inference_engine import InferenceEngine
 from exo.inference.shard import Shard
 
+
 class DummyInferenceEngine(InferenceEngine):
   def __init__(self):
     self.shard = None

+ 2 - 1
exo/inference/mlx/models/qwen2.py

@@ -24,6 +24,7 @@ class ModelArgs(ModelArgs):
 
     self.shard = Shard(**self.shard)
 
+
 class Qwen2Model(nn.Module):
   def __init__(self, args: ModelArgs):
     super().__init__()
@@ -57,7 +58,7 @@ class Qwen2Model(nn.Module):
       mask = create_attention_mask(h, cache)
 
     if cache is None:
-      cache = [None] * len(self.layers)
+      cache = [None]*len(self.layers)
 
     for layer, c in zip(self.layers, cache):
       h = layer(h, mask, c)

+ 5 - 1
exo/inference/mlx/sharded_inference_engine.py

@@ -10,6 +10,7 @@ import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 
+
 class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
@@ -44,7 +45,10 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
 
     if self.shard != shard:
       loop = asyncio.get_running_loop()
-      def load_shard_wrapper(): return asyncio.run(load_shard(model_path, shard))
+
+      def load_shard_wrapper():
+        return asyncio.run(load_shard(model_path, shard))
+
       model_shard, self.tokenizer = await loop.run_in_executor(self.executor, load_shard_wrapper)
       self.stateful_sharded_model = await loop.run_in_executor(self.executor, StatefulShardedModel, shard, model_shard)
       self.shard = shard
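> The wrapper extracted in this hunk is worth spelling out: `load_shard` is itself a coroutine, so the worker thread spins up its own event loop with `asyncio.run`, keeping the heavy load off the main loop. A self-contained sketch of the pattern (the loader body is a stand-in):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def load_shard(path: str) -> str:
  # Stand-in for an expensive async model load.
  await asyncio.sleep(0.1)
  return f"model at {path}"

def load_shard_wrapper(path: str) -> str:
  # The worker thread gets its own event loop, so the async loader
  # runs to completion without touching the caller's loop.
  return asyncio.run(load_shard(path))

async def main():
  loop = asyncio.get_running_loop()
  with ThreadPoolExecutor(max_workers=1) as pool:
    print(await loop.run_in_executor(pool, load_shard_wrapper, "/tmp/shard"))

asyncio.run(main())
```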

+ 1 - 0
exo/inference/mlx/sharded_model.py

@@ -8,6 +8,7 @@ from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
+
 # TODO: support a speculative model so we can parallelise compute across devices
 class StatefulShardedModel:
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):

+ 47 - 42
exo/inference/test_dummy_inference_engine.py

@@ -4,53 +4,58 @@ import numpy as np
 from exo.inference.dummy_inference_engine import DummyInferenceEngine
 from exo.inference.shard import Shard
 
+
 class MockShardDownloader:
-    async def ensure_shard(self, shard):
-        pass
+  async def ensure_shard(self, shard):
+    pass
+
+
 @pytest.mark.asyncio
 async def test_dummy_inference_specific():
-    engine = DummyInferenceEngine(MockShardDownloader())
-    test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    test_prompt = "This is a test prompt"
-    
-    result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
-    
-    print(f"Inference result shape: {result.shape}")
-    print(f"Inference state: {state}")
-    print(f"Is finished: {is_finished}")
-    
-    assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
-    assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
+  engine = DummyInferenceEngine(MockShardDownloader())
+  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+  test_prompt = "This is a test prompt"
+
+  result, state, is_finished = await engine.infer_prompt("test_request", test_shard, test_prompt)
+
+  print(f"Inference result shape: {result.shape}")
+  print(f"Inference state: {state}")
+  print(f"Is finished: {is_finished}")
+
+  assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
+  assert isinstance(json.loads(state), dict), "State should be a valid JSON string"
+  assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
 
 @pytest.mark.asyncio
 async def test_dummy_inference_engine():
-    # Initialize the DummyInferenceEngine
-    engine = DummyInferenceEngine(MockShardDownloader())
-    
-    # Create a test shard
-    shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
-    
-    # Test infer_prompt
-    output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    # Test infer_tensor
-    input_tensor = np.array([[1, 2, 3]])
-    output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
-    
-    assert isinstance(output, np.ndarray), "Output should be a numpy array"
-    assert output.ndim == 2, "Output should be 2-dimensional"
-    assert isinstance(state, str), "State should be a string"
-    assert isinstance(is_finished, bool), "is_finished should be a boolean"
-
-    print("All tests passed!")
+  # Initialize the DummyInferenceEngine
+  engine = DummyInferenceEngine(MockShardDownloader())
+
+  # Create a test shard
+  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
+
+  # Test infer_prompt
+  output, state, is_finished = await engine.infer_prompt("test_id", shard, "Test prompt")
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+  assert isinstance(state, str), "State should be a string"
+  assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+  # Test infer_tensor
+  input_tensor = np.array([[1, 2, 3]])
+  output, state, is_finished = await engine.infer_tensor("test_id", shard, input_tensor)
+
+  assert isinstance(output, np.ndarray), "Output should be a numpy array"
+  assert output.ndim == 2, "Output should be 2-dimensional"
+  assert isinstance(state, str), "State should be a string"
+  assert isinstance(is_finished, bool), "is_finished should be a boolean"
+
+  print("All tests passed!")
+
 
 if __name__ == "__main__":
-    import asyncio
-    asyncio.run(test_dummy_inference_engine())
-    asyncio.run(test_dummy_inference_specific())
+  import asyncio
+  asyncio.run(test_dummy_inference_engine())
+  asyncio.run(test_dummy_inference_specific())

+ 2 - 12
exo/inference/test_inference_engine.py

@@ -44,12 +44,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
   assert np.array_equal(next_resp_full, resp4)
 
 
-asyncio.run(test_inference_engine(
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  MLXDynamicShardInferenceEngine(HFShardDownloader()),
-  "mlx-community/Llama-3.2-1B-Instruct-4bit",
-  16
-))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "mlx-community/Llama-3.2-1B-Instruct-4bit", 16))
 
 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
@@ -57,10 +52,5 @@ if os.getenv("RUN_TINYGRAD", default="0") == "1":
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
   asyncio.run(
-    test_inference_engine(
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      TinygradDynamicShardInferenceEngine(HFShardDownloader()),
-      "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-      32
-    )
+    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", 32)
   )

+ 6 - 1
exo/inference/tokenizers.py

@@ -7,14 +7,18 @@ from transformers import AutoTokenizer, AutoProcessor
 from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
 
+
 class DummyTokenizer:
   def __init__(self):
     self.eos_token_id = 0
+
   def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
-    return [1,2,3]
+    return [1, 2, 3]
+
   def decode(self, tokens):
     return "dummy"
 
+
 async def resolve_tokenizer(model_id: str):
   if model_id == "dummy":
     return DummyTokenizer()
@@ -29,6 +33,7 @@ async def resolve_tokenizer(model_id: str):
     if DEBUG >= 5: traceback.print_exc()
   return await _resolve_tokenizer(model_id)
 
+
 async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
   try:
     if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")

+ 11 - 10
exo/main.py

@@ -59,7 +59,8 @@ print_yellow_exo()
 system_info = get_system_info()
 print(f"Detected system: {system_info}")
 
-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
+                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")
 
@@ -135,16 +136,14 @@ if args.prometheus_client_port:
 
 last_broadcast_time = 0
 
+
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-    global last_broadcast_time
-    current_time = time.time()
-    if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
-        last_broadcast_time = current_time
-        asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
-            "type": "download_progress",
-            "node_id": node.id,
-            "progress": event.to_dict()
-        })))
+  global last_broadcast_time
+  current_time = time.time()
+  if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
+    last_broadcast_time = current_time
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+
 
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
 
@@ -161,6 +160,7 @@ async def shutdown(signal, loop):
     await server.stop()
     loop.stop()
 
+
 async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
     shard = model_base_shards.get(model_name, {}).get(inference_engine.__class__.__name__)
     if not shard:
@@ -234,5 +234,6 @@ def run():
         loop.run_until_complete(shutdown(signal.SIGTERM, loop))
         loop.close()
 
+
 if __name__ == "__main__":
     run()

+ 45 - 221
exo/models.py

@@ -1,225 +1,49 @@
 from exo.inference.shard import Shard
 
 model_base_shards = {
-    ### LLaMA Models
-    "llama-3.2-1b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Llama-3.2-1B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=16,
-        ),
-    },
-    "llama-3.2-3b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Llama-3.2-3B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=28,
-        ),
-    },
-    "llama-3.1-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=32,
-        ),
-        "TinygradDynamicShardInferenceEngine": Shard(
-            model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
-            start_layer=0,
-            end_layer=0,
-            n_layers=32,
-        ),
-    },
-    "llama-3.1-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-        "TinygradDynamicShardInferenceEngine": Shard(
-            model_id="NousResearch/Meta-Llama-3.1-70B-Instruct",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    "llama-3.1-70b-bf16": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-        "TinygradDynamicShardInferenceEngine": Shard(
-            model_id="NousResearch/Meta-Llama-3.1-70B-Instruct",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    "llama-3.1-405b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Meta-Llama-3.1-405B-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=126,
-        ),
-    },
-    "llama-3-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=32,
-        ),
-        "TinygradDynamicShardInferenceEngine": Shard(
-            model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
-            start_layer=0,
-            end_layer=0,
-            n_layers=32,
-        ),
-    },
-    "llama-3-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-        "TinygradDynamicShardInferenceEngine": Shard(
-            model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    ### Mistral Models
-    "mistral-nemo": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=40,
-        ),
-    },
-    "mistral-large": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Mistral-Large-Instruct-2407-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=88,
-        ),
-    },
-    ### DeepSeek Models
-    "deepseek-coder-v2-lite": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx",
-            start_layer=0,
-            end_layer=0,
-            n_layers=27,
-        ),
-    },
-    "deepseek-coder-v2.5": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64",
-            start_layer=0,
-            end_layer=0,
-            n_layers=60,
-        ),
-    },
-    ### Llava Models
-    "llava-1.5-7b-hf": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="llava-hf/llava-1.5-7b-hf",
-            start_layer=0,
-            end_layer=0,
-            n_layers=32,
-        ),
-    },
-    ### Qwen Models
-    "qwen-2.5-coder-1.5b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=28,
-        ),
-    },
-    "qwen-2.5-coder-7b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=28,
-        ),
-    },
-    "qwen-2.5-7b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-7B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=28,
-        ),
-    },
-    "qwen-2.5-math-7b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=28,
-        ),
-    },
-    "qwen-2.5-14b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-14B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=48,
-        ),
-    },
-    "qwen-2.5-72b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-72B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    "qwen-2.5-math-72b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    ### Nemotron Models
-    "nemotron-70b": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    "nemotron-70b-bf16": {
-        "MLXDynamicShardInferenceEngine": Shard(
-            model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16",
-            start_layer=0,
-            end_layer=0,
-            n_layers=80,
-        ),
-    },
-    ### Dummy Model
-    "dummy": {
-        "DummyInferenceEngine": Shard(
-            model_id="dummy",
-            start_layer=0,
-            end_layer=7,
-            n_layers=8,
-        ),
-    },
+  ### llama
+  "llama-3.2-1b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=16),},
+  "llama-3.2-3b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.2-3B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "llama-3.1-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3.1-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-70b-bf16": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
+  },
+  "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
+  "llama-3-8b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", start_layer=0, end_layer=0, n_layers=32),
+  },
+  "llama-3-70b": {
+    "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", start_layer=0, end_layer=0, n_layers=80),
+  },
+  ### mistral
+  "mistral-nemo": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),},
+  "mistral-large": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),},
+  ### deepseek
+  "deepseek-coder-v2-lite": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", start_layer=0, end_layer=0, n_layers=27),},
+  "deepseek-coder-v2.5": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", start_layer=0, end_layer=0, n_layers=60),},
+  ### llava
+  "llava-1.5-7b-hf": {"MLXDynamicShardInferenceEngine": Shard(model_id="llava-hf/llava-1.5-7b-hf", start_layer=0, end_layer=0, n_layers=32),},
+  ### qwen
+  "qwen-2.5-coder-1.5b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-coder-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-math-7b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-7B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=28),},
+  "qwen-2.5-14b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-14B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=48),},
+  "qwen-2.5-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
+  "qwen-2.5-math-72b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Qwen2.5-Math-72B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),},
+  ### nemotron
+  "nemotron-70b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", start_layer=0, end_layer=0, n_layers=80),},
+  "nemotron-70b-bf16": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", start_layer=0, end_layer=0, n_layers=80),},
+  # dummy
+  "dummy": {"DummyInferenceEngine": Shard(model_id="dummy", start_layer=0, end_layer=7, n_layers=8),},
 }
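> The flattened table keeps the same two-level lookup used elsewhere in the codebase (see `run_model_cli` in `exo/main.py`): model name first, then the inference engine's class name. For example:

```python
from exo.models import model_base_shards

# type(engine).__name__ in the real code path
engine_class = "MLXDynamicShardInferenceEngine"

shard = model_base_shards.get("llama-3.2-1b", {}).get(engine_class)
if shard is None:
  raise ValueError("model not supported by this inference engine")
print(shard.model_id, shard.n_layers)  # mlx-community/Llama-3.2-1B-Instruct-4bit 16
```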

+ 4 - 2
exo/networking/grpc/grpc_peer_handle.py

@@ -9,7 +9,7 @@ from . import node_service_pb2_grpc
 from ..peer_handle import PeerHandle
 from exo.inference.shard import Shard
 from exo.topology.topology import Topology
-from exo.topology.device_capabilities import DeviceCapabilities
+from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from exo.helpers import DEBUG
 
 
@@ -117,7 +117,9 @@ class GRPCPeerHandle(PeerHandle):
     response = await self.stub.CollectTopology(request)
     topology = Topology()
     for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops)
+      device_capabilities = DeviceCapabilities(
+        model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+      )
       topology.update_node(node_id, device_capabilities)
     for node_id, peers in response.peer_graph.items():
       for peer_id in peers.peer_ids:
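> This hunk is a behavior fix, not just formatting: the protobuf `capabilities.flops` is a nested message, and passing it straight into `DeviceCapabilities` left a non-`DeviceFlops` object in the `flops` field. Rebuilding it field by field keeps the topology's types consistent. A reduced sketch of the conversion, with stand-in dataclasses in place of the real `exo.topology.device_capabilities` types:

```python
from dataclasses import dataclass

@dataclass
class DeviceFlops:
  fp32: float
  fp16: float
  int8: float

@dataclass
class DeviceCapabilities:
  model: str
  chip: str
  memory: int
  flops: DeviceFlops

class PbFlops:  # stand-in for the protobuf nested message
  fp32, fp16, int8 = 1.0, 2.0, 4.0

caps = DeviceCapabilities(
  model="m", chip="c", memory=8,
  flops=DeviceFlops(fp16=PbFlops.fp16, fp32=PbFlops.fp32, int8=PbFlops.int8),
)
print(caps.flops)
```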

File diff suppressed because it is too large
+ 0 - 3
exo/networking/grpc/node_service_pb2.py


+ 263 - 314
exo/networking/grpc/node_service_pb2_grpc.py

@@ -12,349 +12,298 @@ SCHEDULED_RELEASE_DATE = 'June 25, 2024'
 _version_not_supported = False
 
 try:
-    from grpc._utilities import first_version_is_lower
-    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+  from grpc._utilities import first_version_is_lower
+  _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
 except ImportError:
-    _version_not_supported = True
+  _version_not_supported = True
 
 if _version_not_supported:
-    warnings.warn(
-        f'The grpc package installed is at version {GRPC_VERSION},'
-        + f' but the generated code in node_service_pb2_grpc.py depends on'
-        + f' grpcio>={GRPC_GENERATED_VERSION}.'
-        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-        + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
-        + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
-        RuntimeWarning
-    )
+  warnings.warn(
+    f'The grpc package installed is at version {GRPC_VERSION},' + f' but the generated code in node_service_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' +
+    f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' +
+    f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', RuntimeWarning
+  )
 
 
 class NodeServiceStub(object):
-    """Missing associated documentation comment in .proto file."""
-
-    def __init__(self, channel):
-        """Constructor.
+  """Missing associated documentation comment in .proto file."""
+  def __init__(self, channel):
+    """Constructor.
 
         Args:
             channel: A grpc.Channel.
         """
-        self.SendPrompt = channel.unary_unary(
-                '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.SendTensor = channel.unary_unary(
-                '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
-                _registered_method=True)
-        self.GetInferenceResult = channel.unary_unary(
-                '/node_service.NodeService/GetInferenceResult',
-                request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.InferenceResult.FromString,
-                _registered_method=True)
-        self.CollectTopology = channel.unary_unary(
-                '/node_service.NodeService/CollectTopology',
-                request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
-                _registered_method=True)
-        self.SendResult = channel.unary_unary(
-                '/node_service.NodeService/SendResult',
-                request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.SendOpaqueStatus = channel.unary_unary(
-                '/node_service.NodeService/SendOpaqueStatus',
-                request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
-        self.HealthCheck = channel.unary_unary(
-                '/node_service.NodeService/HealthCheck',
-                request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
-                response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
-                _registered_method=True)
+    self.SendPrompt = channel.unary_unary(
+      '/node_service.NodeService/SendPrompt',
+      request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.SendTensor = channel.unary_unary(
+      '/node_service.NodeService/SendTensor',
+      request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Tensor.FromString,
+      _registered_method=True
+    )
+    self.GetInferenceResult = channel.unary_unary(
+      '/node_service.NodeService/GetInferenceResult',
+      request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.InferenceResult.FromString,
+      _registered_method=True
+    )
+    self.CollectTopology = channel.unary_unary(
+      '/node_service.NodeService/CollectTopology',
+      request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Topology.FromString,
+      _registered_method=True
+    )
+    self.SendResult = channel.unary_unary(
+      '/node_service.NodeService/SendResult',
+      request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.SendOpaqueStatus = channel.unary_unary(
+      '/node_service.NodeService/SendOpaqueStatus',
+      request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      response_deserializer=node__service__pb2.Empty.FromString,
+      _registered_method=True
+    )
+    self.HealthCheck = channel.unary_unary(
+      '/node_service.NodeService/HealthCheck',
+      request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
+      response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
+      _registered_method=True
+    )
 
 
 class NodeServiceServicer(object):
+  """Missing associated documentation comment in .proto file."""
+  def SendPrompt(self, request, context):
     """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendPrompt(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
-    def SendTensor(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendTensor(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def GetInferenceResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def GetInferenceResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def CollectTopology(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def CollectTopology(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendResult(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendResult(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def SendOpaqueStatus(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def SendOpaqueStatus(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
-    def HealthCheck(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
+  def HealthCheck(self, request, context):
+    """Missing associated documentation comment in .proto file."""
+    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+    context.set_details('Method not implemented!')
+    raise NotImplementedError('Method not implemented!')
 
 
 def add_NodeServiceServicer_to_server(servicer, server):
-    rpc_method_handlers = {
-            'SendPrompt': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'SendTensor': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
-            ),
-            'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.GetInferenceResult,
-                    request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
-                    response_serializer=node__service__pb2.InferenceResult.SerializeToString,
-            ),
-            'CollectTopology': grpc.unary_unary_rpc_method_handler(
-                    servicer.CollectTopology,
-                    request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
-            ),
-            'SendResult': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendResult,
-                    request_deserializer=node__service__pb2.SendResultRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
-                    servicer.SendOpaqueStatus,
-                    request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
-            'HealthCheck': grpc.unary_unary_rpc_method_handler(
-                    servicer.HealthCheck,
-                    request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
-                    response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
-            ),
-    }
-    generic_handler = grpc.method_handlers_generic_handler(
-            'node_service.NodeService', rpc_method_handlers)
-    server.add_generic_rpc_handlers((generic_handler,))
-    server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
+  rpc_method_handlers = {
+    'SendPrompt':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendPrompt,
+        request_deserializer=node__service__pb2.PromptRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'SendTensor':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendTensor,
+        request_deserializer=node__service__pb2.TensorRequest.FromString,
+        response_serializer=node__service__pb2.Tensor.SerializeToString,
+      ),
+    'GetInferenceResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.GetInferenceResult,
+        request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
+        response_serializer=node__service__pb2.InferenceResult.SerializeToString,
+      ),
+    'CollectTopology':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.CollectTopology,
+        request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
+        response_serializer=node__service__pb2.Topology.SerializeToString,
+      ),
+    'SendResult':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendResult,
+        request_deserializer=node__service__pb2.SendResultRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'SendOpaqueStatus':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.SendOpaqueStatus,
+        request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
+        response_serializer=node__service__pb2.Empty.SerializeToString,
+      ),
+    'HealthCheck':
+      grpc.unary_unary_rpc_method_handler(
+        servicer.HealthCheck,
+        request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
+        response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
+      ),
+  }
+  generic_handler = grpc.method_handlers_generic_handler('node_service.NodeService', rpc_method_handlers)
+  server.add_generic_rpc_handlers((generic_handler,))
+  server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
 
 
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
 class NodeService(object):
-    """Missing associated documentation comment in .proto file."""
-
-    @staticmethod
-    def SendPrompt(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  """Missing associated documentation comment in .proto file."""
+  @staticmethod
+  def SendPrompt(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendPrompt',
+      node__service__pb2.PromptRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendTensor(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendTensor(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendTensor',
+      node__service__pb2.TensorRequest.SerializeToString,
+      node__service__pb2.Tensor.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def GetInferenceResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GetInferenceResult',
-            node__service__pb2.GetInferenceResultRequest.SerializeToString,
-            node__service__pb2.InferenceResult.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def GetInferenceResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/GetInferenceResult',
+      node__service__pb2.GetInferenceResultRequest.SerializeToString,
+      node__service__pb2.InferenceResult.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def CollectTopology(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/CollectTopology',
-            node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def CollectTopology(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/CollectTopology',
+      node__service__pb2.CollectTopologyRequest.SerializeToString,
+      node__service__pb2.Topology.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendResult(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendResult',
-            node__service__pb2.SendResultRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendResult(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendResult',
+      node__service__pb2.SendResultRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def SendOpaqueStatus(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/SendOpaqueStatus',
-            node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def SendOpaqueStatus(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/SendOpaqueStatus',
+      node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
+      node__service__pb2.Empty.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
 
-    @staticmethod
-    def HealthCheck(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/HealthCheck',
-            node__service__pb2.HealthCheckRequest.SerializeToString,
-            node__service__pb2.HealthCheckResponse.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
+  @staticmethod
+  def HealthCheck(request, target, options=(), channel_credentials=None, call_credentials=None, insecure=False, compression=None, wait_for_ready=None, timeout=None, metadata=None):
+    return grpc.experimental.unary_unary(
+      request,
+      target,
+      '/node_service.NodeService/HealthCheck',
+      node__service__pb2.HealthCheckRequest.SerializeToString,
+      node__service__pb2.HealthCheckResponse.FromString,
+      options,
+      channel_credentials,
+      insecure,
+      call_credentials,
+      compression,
+      wait_for_ready,
+      timeout,
+      metadata,
+      _registered_method=True
+    )
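
Note: a minimal sketch of serving and calling these regenerated stubs, assuming the modules import as exo.networking.grpc.node_service_pb2 and node_service_pb2_grpc; the servicer subclass, the port, and the HealthCheckResponse field name are illustrative assumptions, not part of this diff.

from concurrent import futures

import grpc

from exo.networking.grpc import node_service_pb2
from exo.networking.grpc import node_service_pb2_grpc


class HealthyNodeService(node_service_pb2_grpc.NodeServiceServicer):
  def HealthCheck(self, request, context):
    # Field name assumed; check node_service.proto for the actual schema.
    return node_service_pb2.HealthCheckResponse(is_healthy=True)


server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
node_service_pb2_grpc.add_NodeServiceServicer_to_server(HealthyNodeService(), server)
server.add_insecure_port('[::]:50051')
server.start()

# Client side, via the experimental static helpers defined above:
response = node_service_pb2_grpc.NodeService.HealthCheck(node_service_pb2.HealthCheckRequest(), 'localhost:50051', insecure=True)
server.stop(None)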

+ 9 - 7
exo/networking/manual/manual_discovery.py

@@ -19,7 +19,9 @@ class ManualDiscovery(Discovery):
     self.create_peer_handle = create_peer_handle
 
     if node_id not in self.topology.peers:
-      raise ValueError(f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}")
+      raise ValueError(
+        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
+      )
 
     self.listen_task = None
 
@@ -42,7 +44,6 @@ class ManualDiscovery(Discovery):
     if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
     return list(self.known_peers.values())
 
-
   async def task_find_peers_from_config(self):
     if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
     while True:
@@ -52,18 +53,19 @@ class ManualDiscovery(Discovery):
           peer = self.known_peers.get(peer_id)
           if not peer:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
-            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)  
+            peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", peer_config.device_capabilities)
           is_healthy = await peer.health_check()
           if is_healthy:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
             self.known_peers[peer_id] = peer
           else:
             if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
-            try: del self.known_peers[peer_id]
-            except KeyError: pass
+            try:
+              del self.known_peers[peer_id]
+            except KeyError:
+              pass
         except Exception as e:
-            if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
+          if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
       await asyncio.sleep(1.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
-

+ 0 - 1
exo/networking/manual/network_topology_config.py

@@ -17,7 +17,6 @@ class NetworkTopology(BaseModel):
   """
   node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
   """
-
   @classmethod
   def from_path(cls, path: str) -> "NetworkTopology":
     try:
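
Note: ManualDiscovery above requires node_id to be one of the keys in the config that NetworkTopology.from_path loads. A sketch of a plausible config follows, with field names inferred from the attributes read in these diffs (address, port, device_capabilities with model/chip/memory/flops); the top-level "peers" key and all values are illustrative assumptions.

import json

config = {
  "peers": {
    "node-a": {  # run that node with node_id="node-a"
      "address": "192.168.1.10",
      "port": 50051,
      "device_capabilities": {
        "model": "MacBook Pro",
        "chip": "Apple M2",
        "memory": 16384,
        "flops": {"fp16": 0.0, "fp32": 0.0, "int8": 0.0},
      },
    },
  },
}
with open("network_config.json", "w") as f:
  json.dump(config, f, indent=2)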

+ 1 - 0
exo/networking/peer_handle.py

@@ -5,6 +5,7 @@ from exo.inference.shard import Shard
 from exo.topology.device_capabilities import DeviceCapabilities
 from exo.topology.topology import Topology
 
+
 class PeerHandle(ABC):
   @abstractmethod
   def id(self) -> str:

+ 11 - 11
exo/networking/tailscale/tailscale_discovery.py

@@ -8,6 +8,7 @@ from exo.topology.device_capabilities import DeviceCapabilities, device_capabili
 from exo.helpers import DEBUG, DEBUG_DISCOVERY
 from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
 
+
 class TailscaleDiscovery(Discovery):
   def __init__(
     self,
@@ -69,14 +70,11 @@ class TailscaleDiscovery(Discovery):
         devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
         current_time = time.time()
 
-        active_devices = {
-          name: device for name, device in devices.items()
-          if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30
-        }
+        active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
 
         if DEBUG_DISCOVERY >= 4: print(f"Found tailscale devices: {devices}")
         if DEBUG_DISCOVERY >= 2: print(f"Active tailscale devices: {len(active_devices)}/{len(devices)}")
-        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time  - device.last_seen.timestamp()) for device in devices.values()])
+        if DEBUG_DISCOVERY >= 2: print("Time since last seen tailscale devices", [(current_time - device.last_seen.timestamp()) for device in devices.values()])
 
         for device in active_devices.values():
           if device.name == self.node_id: continue
@@ -141,7 +139,13 @@ class TailscaleDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}" for peer_handle, connected_at, last_seen in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}"
+              for peer_handle, connected_at, last_seen in self.known_peers.values()
+            }
+          )
 
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
@@ -164,9 +168,5 @@ class TailscaleDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove
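
Note: the condensed should_remove expression above, restated as a pure function for readability; this is a paraphrase for illustration, not a helper exported by the module.

def should_remove_peer(is_connected: bool, health_ok: bool, connected_at: float, last_seen: float, current_time: float, discovery_timeout: float) -> bool:
  # Drop a peer that never connected within the timeout, went silent past the
  # timeout, or failed its health check.
  return ((not is_connected and current_time - connected_at > discovery_timeout) or (current_time - last_seen > discovery_timeout) or (not health_ok))

# Never connected and 40s elapsed against a 30s timeout -> removed.
assert should_remove_peer(False, True, connected_at=0.0, last_seen=35.0, current_time=40.0, discovery_timeout=30.0)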

+ 15 - 26
exo/networking/tailscale/tailscale_helpers.py

@@ -7,6 +7,7 @@ from exo.helpers import DEBUG_DISCOVERY
 from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
 from datetime import datetime, timezone
 
+
 class Device:
   def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
     self.device_id = device_id
@@ -16,12 +17,7 @@ class Device:
 
   @classmethod
   def from_dict(cls, data: Dict[str, Any]) -> 'Device':
-    return cls(
-      device_id=data.get('id', ''),
-      name=data.get('name', ''),
-      addresses=data.get('addresses', []),
-      last_seen=cls.parse_datetime(data.get('lastSeen'))
-    )
+    return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen')))
 
   @staticmethod
   def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
@@ -29,13 +25,10 @@ class Device:
       return None
     return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
 
+
 async def get_device_id() -> str:
   try:
-    process = await asyncio.create_subprocess_exec(
-      'tailscale', 'status', '--json',
-      stdout=asyncio.subprocess.PIPE,
-      stderr=asyncio.subprocess.PIPE
-    )
+    process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
     stdout, stderr = await process.communicate()
     if process.returncode != 0:
       raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.")
@@ -45,22 +38,16 @@ async def get_device_id() -> str:
   except Exception as e:
     raise Exception(f"{str(e)} Do you have the tailscale cli installed? See: https://tailscale.com/kb/1080/cli")
 
+
 async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities):
   async with aiohttp.ClientSession() as session:
     base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}',
-      'Content-Type': 'application/json'
-    }
+    headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
 
     attributes = {
-      "custom:exo_node_id": node_id.replace('-', '_'),
-      "custom:exo_node_port": node_port,
-      "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
-      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model),
-      "custom:exo_device_capability_memory": str(device_capabilities.memory),
-      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16),
-      "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
+      "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip),
+      "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory),
+      "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32),
       "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8)
     }
 
@@ -73,12 +60,11 @@ async def update_device_attributes(device_id: str, api_key: str, node_id: str, n
         else:
           print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}")
 
+
 async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes"
-    headers = {
-      'Authorization': f'Bearer {api_key}'
-    }
+    headers = {'Authorization': f'Bearer {api_key}'}
     async with session.get(url, headers=headers) as response:
       if response.status == 200:
         data = await response.json()
@@ -100,6 +86,7 @@ async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int,
         print(f"Failed to fetch posture attributes for {device_id}: {response.status}")
         return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0))
 
+
 def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
   result = {}
   prefix = "custom:exo_"
@@ -112,12 +99,14 @@ def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]:
         result[attr_name] = float(value)
   return result
 
+
 def sanitize_attribute(value: str) -> str:
   # Replace invalid characters with underscores
   sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
   # Truncate to 50 characters
   return sanitized_value[:50]
 
+
 async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
   async with aiohttp.ClientSession() as session:
     url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
@@ -133,4 +122,4 @@ async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]
         device = Device.from_dict(device_data)
         devices[device.name] = device
 
-      return devices
+      return devices
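
Note: the sanitizer's behaviour follows directly from the regex and truncation shown above; a quick self-contained check (the chip string is made up).

import re

def sanitize_attribute(value: str) -> str:
  sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
  return sanitized_value[:50]

assert sanitize_attribute("Apple M2 Pro") == "Apple_M2_Pro"
assert len(sanitize_attribute("x" * 80)) == 50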

+ 2 - 0
exo/networking/tailscale/test_tailscale_discovery.py

@@ -5,6 +5,7 @@ from unittest import mock
 from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
 from exo.networking.peer_handle import PeerHandle
 
+
 class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
   async def asyncSetUp(self):
     self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "")
@@ -37,5 +38,6 @@ class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase):
     # Check if discovered peers are instances of GRPCPeerHandle
     print(peers)
 
+
 if __name__ == '__main__':
   unittest.main()

+ 14 - 15
exo/networking/udp/udp_discovery.py

@@ -9,6 +9,7 @@ from exo.networking.peer_handle import PeerHandle
 from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
 from exo.helpers import DEBUG, DEBUG_DISCOVERY, get_all_ip_addresses
 
+
 class ListenProtocol(asyncio.DatagramProtocol):
   def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]):
     super().__init__()
@@ -90,17 +91,13 @@ class UDPDiscovery(Discovery):
           "node_id": self.node_id,
           "grpc_port": self.node_port,
           "device_capabilities": self.device_capabilities.to_dict(),
-          "priority": 1, # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
+          "priority": 1,  # For now, every interface has the same priority. We can make this better by prioriting interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
         })
         if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr}): {message}")
 
         transport = None
         try:
-          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
-            lambda: BroadcastProtocol(message, self.broadcast_port),
-            local_addr=(addr, 0),
-            family=socket.AF_INET
-          )
+          transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
           if DEBUG_DISCOVERY >= 3:
             print(f"Broadcasting presence at ({addr})")
         except Exception as e:
@@ -145,7 +142,8 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers:
           existing_peer_prio = self.known_peers[peer_id][3]
           if existing_peer_prio >= peer_prio:
-            if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
+            if DEBUG >= 1:
+              print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}")
             return
         new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", device_capabilities)
         if not await new_peer_handle.health_check():
@@ -161,8 +159,7 @@ class UDPDiscovery(Discovery):
         if peer_id in self.known_peers: self.known_peers[peer_id] = (self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio)
 
   async def task_listen_for_peers(self):
-    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message),
-                                                            local_addr=("0.0.0.0", self.listen_port))
+    await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=("0.0.0.0", self.listen_port))
     if DEBUG_DISCOVERY >= 2: print("Started listen task")
 
   async def task_cleanup_peers(self):
@@ -177,7 +174,13 @@ class UDPDiscovery(Discovery):
         for peer_id, should_remove in zip(peer_ids, results):
           if should_remove: peers_to_remove.append(peer_id)
 
-        if DEBUG_DISCOVERY >= 2: print("Peer statuses:", { peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}" for peer_handle, connected_at, last_seen, prio in self.known_peers.values() })
+        if DEBUG_DISCOVERY >= 2:
+          print(
+            "Peer statuses:", {
+              peer_handle.id(): f"is_connected={await peer_handle.is_connected()}, health_check={await peer_handle.health_check()}, connected_at={connected_at}, last_seen={last_seen}, prio={prio}"
+              for peer_handle, connected_at, last_seen, prio in self.known_peers.values()
+            }
+          )
 
         for peer_id in peers_to_remove:
           if peer_id in self.known_peers:
@@ -200,9 +203,5 @@ class UDPDiscovery(Discovery):
       if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
       return True
 
-    should_remove = (
-      (not is_connected and current_time - connected_at > self.discovery_timeout) or
-      (current_time - last_seen > self.discovery_timeout) or
-      (not health_ok)
-    )
+    should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
     return should_remove
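
Note: the priority rule in the listener above, in isolation: an announcement only replaces a known peer when it arrives with a strictly higher priority. The tuple layout (peer_handle, connected_at, last_seen, prio) matches the diff; the values below are illustrative.

known_peers = {"node-a": (None, 0.0, 0.0, 1)}  # peer_handle elided for the sketch

def should_ignore(peer_id: str, peer_prio: int) -> bool:
  if peer_id in known_peers:
    existing_peer_prio = known_peers[peer_id][3]
    return existing_peer_prio >= peer_prio
  return False

assert should_ignore("node-a", 1)      # equal priority is ignored
assert not should_ignore("node-a", 2)  # strictly higher priority wins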

+ 15 - 26
exo/orchestration/standard_node.py

@@ -18,6 +18,7 @@ from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
 
+
 class StandardNode(Node):
   def __init__(
     self,
@@ -87,18 +88,14 @@ class StandardNode(Node):
   def get_supported_inference_engines(self):
     supported_engine_names = []
     if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
-        supported_engine_names.append('mlx')
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('mlx')
+      supported_engine_names.append('tinygrad')
     else:
-        supported_engine_names.append('tinygrad')
+      supported_engine_names.append('tinygrad')
     return supported_engine_names
 
   async def broadcast_supported_engines(self, supported_engines_names: List[str]):
-    status_message = json.dumps({
-        "type": "supported_inference_engines",
-        "node_id": self.id,
-        "engines": supported_engines_names
-    })
+    status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
     await self.broadcast_opaque_status("", status_message)
 
   def get_topology_inference_engines(self) -> List[List[str]]:
@@ -311,20 +308,16 @@ class StandardNode(Node):
     next_peer_ids = {peer.id() for peer in next_peers}
     peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
     peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
-    peers_updated = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
-    peers_unchanged = [
-      peer for peer in next_peers
-      if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())
-    ]
+    peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+    peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
     peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
     peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
 
     def _pretty(peers: List[PeerHandle]) -> List[str]:
       return [f"{peer.id()}@{peer.addr()}" for peer in peers]
-    if DEBUG >= 2: print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+    if DEBUG >= 2:
+      print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
 
     async def disconnect_with_timeout(peer, timeout=5):
       try:
@@ -344,14 +337,8 @@ class StandardNode(Node):
         traceback.print_exc()
         return False
 
-    disconnect_results = await asyncio.gather(
-      *(disconnect_with_timeout(peer) for peer in peers_to_disconnect),
-      return_exceptions=True
-    )
-    connect_results = await asyncio.gather(
-      *(connect_with_timeout(peer) for peer in peers_to_connect),
-      return_exceptions=True
-    )
+    disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+    connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
 
     successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
     failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
@@ -375,7 +362,7 @@ class StandardNode(Node):
         self.inference_engine = get_inference_engine("tinygrad", self.shard_downloader)
       else:
         if DEBUG >= 1: print("All nodes can use mlx, using mlx for inference")
-        self.inference_engine = get_inference_engine("mlx", self.shard_downloader) 
+        self.inference_engine = get_inference_engine("mlx", self.shard_downloader)
 
   async def periodic_topology_collection(self, interval: int):
     while True:
@@ -422,6 +409,7 @@ class StandardNode(Node):
         self.topology.merge(other_topology)
       except Exception as e:
         print(f"Error collecting topology from {peer.id()}: {e}")
+        traceback.print_exc()
 
     next_topology.active_node_id = self.topology.active_node_id  # this is not so clean.
     self.topology = next_topology
@@ -464,6 +452,7 @@ class StandardNode(Node):
       except Exception as e:
         print(f"Error sending opaque status to {peer.id()}: {e}")
         traceback.print_exc()
+
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
     self.on_opaque_status.trigger_all(request_id, status)
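
Note: the peer bookkeeping in update_peers above, restated with plain id-to-address dicts; a sketch of the set arithmetic only, not the node's API.

current = {"a": "10.0.0.1:50051", "b": "10.0.0.2:50051"}
nxt = {"b": "10.0.0.9:50051", "c": "10.0.0.3:50051"}

added = [p for p in nxt if p not in current]
removed = [p for p in current if p not in nxt]
updated = [p for p in nxt if p in current and current[p] != nxt[p]]
unchanged = [p for p in nxt if p in current and current[p] == nxt[p]]

assert (added, removed, updated, unchanged) == (["c"], ["a"], ["b"], [])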

+ 44 - 41
exo/tinychat/update_deps.py

@@ -4,49 +4,52 @@ from bs4 import BeautifulSoup
 from urllib.parse import urljoin, urlparse
 import re
 
+
 def download_file(url, local_path):
-    response = requests.get(url)
-    if response.status_code == 200:
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, 'wb') as f:
-            f.write(response.content)
-        print(f"Downloaded: {local_path}")
-    else:
-        print(response.status_code)
-        print(f"Failed to download: {url}")
+  response = requests.get(url)
+  if response.status_code == 200:
+    os.makedirs(os.path.dirname(local_path), exist_ok=True)
+    with open(local_path, 'wb') as f:
+      f.write(response.content)
+    print(f"Downloaded: {local_path}")
+  else:
+    print(response.status_code)
+    print(f"Failed to download: {url}")
+
 
 def update_html(html_content, base_url):
-    soup = BeautifulSoup(html_content, 'html.parser')
+  soup = BeautifulSoup(html_content, 'html.parser')
 
-    for tag in soup.find_all(['script', 'link']):
-        if tag.has_attr('src'):
-            url = tag['src']
-        elif tag.has_attr('href'):
-            url = tag['href']
-        else:
-            continue
+  for tag in soup.find_all(['script', 'link']):
+    if tag.has_attr('src'):
+      url = tag['src']
+    elif tag.has_attr('href'):
+      url = tag['href']
+    else:
+      continue
+
+    if url.startswith(('http://', 'https://')):
+      full_url = url
+    else:
+      full_url = urljoin(base_url, url)
 
-        if url.startswith(('http://', 'https://')):
-            full_url = url
-        else:
-            full_url = urljoin(base_url, url)
+    parsed_url = urlparse(full_url)
+    local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
 
-        parsed_url = urlparse(full_url)
-        local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/'))
+    download_file(full_url, local_path)
 
-        download_file(full_url, local_path)
+    relative_path = os.path.relpath(local_path, '.')
+    if tag.name == 'script':
+      tag['src'] = "/" + relative_path
+    elif tag.name == 'link':
+      tag['href'] = "/" + relative_path
 
-        relative_path = os.path.relpath(local_path, '.')
-        if tag.name == 'script':
-            tag['src'] = "/" + relative_path
-        elif tag.name == 'link':
-            tag['href'] = "/" + relative_path
+  return str(soup)
 
-    return str(soup)
 
 # Read the HTML file
 with open('./index.html', 'r') as f:
-    html_content = f.read()
+  html_content = f.read()
 
 # Update HTML and download files
 # updated_html = update_html(html_content, 'https://example.com')
@@ -68,7 +71,7 @@ download_file(css_url, css_output_path)
 
 # Parse CSS file for font URLs
 with open(css_output_path, 'r', encoding='utf-8') as f:
-    css_content = f.read()
+  css_content = f.read()
 
 # Extract font URLs from the CSS content
 font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content)
@@ -77,14 +80,14 @@ print(f"Found {len(font_urls)} font URLs")
 
 # Download font files
 for font_url in font_urls:
-    font_url = font_url.strip('"\'')
-    if font_url.startswith('../'):
-        font_url = font_url[3:]
+  font_url = font_url.strip('"\'')
+  if font_url.startswith('../'):
+    font_url = font_url[3:]
 
-    # Use base_url instead of urljoin to keep the version number
-    full_url = base_url + font_url
-    relative_path = font_url
-    output_path = os.path.join(output_dir, relative_path)
-    download_file(full_url, output_path)
+  # Use base_url instead of urljoin to keep the version number
+  full_url = base_url + font_url
+  relative_path = font_url
+  output_path = os.path.join(output_dir, relative_path)
+  download_file(full_url, output_path)
 
-print("Download complete!")
+print("Download complete!")

+ 1 - 1
format.py

@@ -21,7 +21,7 @@ def run_yapf(target):
 
 def main():
   if len(sys.argv) < 2:
-    print("Usage: python format.py <directory_or_file>")
+    print("Usage: python3 format.py <directory_or_file> e.g. python3 format.py ./exo")
     sys.exit(1)
 
   target = sys.argv[1]
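
Note: run_yapf itself isn't shown in this diff; a plausible sketch that shells out to yapf's CLI (--in-place and --recursive are real yapf flags, the rest is a guess at the implementation).

import os
import subprocess
import sys

def run_yapf(target: str) -> int:
  cmd = [sys.executable, "-m", "yapf", "--in-place"]
  if os.path.isdir(target):
    cmd.append("--recursive")
  cmd.append(target)
  return subprocess.run(cmd).returncode

# Usage, matching the message in main(): python3 format.py ./exo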

+ 0 - 5
lint.sh

@@ -1,5 +0,0 @@
-#!/bin/bash
-
-pip3 install -e '.[linting]'
-python3 -m ruff check .
-python3 -m pylint .

+ 0 - 7
pyproject.toml

@@ -1,7 +0,0 @@
-[tool.pylint.format]
-indent-string = '  '
-max-line-length = 200
-
-[tool.autopep8]
-max_line_length = 200
-indent_size = 2

+ 0 - 43
ruff.toml

@@ -1,43 +0,0 @@
-indent-width = 2
-preview = true
-target-version = "py312"
-
-lint.select = [
-  "F",  # Pyflakes
-  "W6",
-  "E71",
-  "E72",
-  "E112",   # no-indented-block
-  "E113",   # unexpected-indentation
-  # "E124",
-  "E203",   # whitespace-before-punctuation
-  "E272",   # multiple-spaces-before-keyword
-  "E303",   # too-many-blank-lines
-  "E304",   # blank-line-after-decorator
-  "E501",   # line-too-long
-  # "E502",
-  "E702",   # multiple-statements-on-one-line-semicolon
-  "E703",   # useless-semicolon
-  "E731",   # lambda-assignment
-  "W191",   # tab-indentation
-  "W291",   # trailing-whitespace
-  "W293",   # blank-line-with-whitespace
-  "UP039",  # unnecessary-class-parentheses
-  "C416",   # unnecessary-comprehension
-  "RET506", # superfluous-else-raise
-  "RET507", # superfluous-else-continue
-  "A",      # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
-  "SIM105", # suppressible-exception
-  "FURB110",# if-exp-instead-of-or-operator
-]
-
-line-length = 200
-
-exclude = [
-  "docs/",
-  "examples/",
-  "extra/",
-  "exo/networking/grpc/node_service_pb2.py",
-  "exo/networking/grpc/node_service_pb2_grpc.py",
-  "exo/helpers.py",
-]

+ 1 - 4
setup.py

@@ -30,10 +30,7 @@ install_requires = [
 ]
 
 extras_require = {
-  "linting": [
-    "pylint==3.2.6",
-    "ruff==0.5.5",
-    "mypy==1.11.0",
+  "formatting": [
     "yapf==0.40.2",
   ],
   "apple_silicon": [

Some files were not shown because too many files changed in this diff