
Merge branch 'main' of github.com:xeb/exo

Mark Kockerbeck 1 year ago
parent commit f1cd5ae7a6

+ 3 - 1
README.md

@@ -109,7 +109,9 @@ That's it! No configuration required - exo will automatically discover the other
 
 The native way to access models running on exo is using the exo library with peer handles. See how in [this example for Llama 3](examples/llama3_distributed.py).
 
-exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000. Note: this is currently only supported by tail nodes (i.e. nodes selected to be at the end of the ring topology). Example request:
+exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:8000
+
+For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000/v1/chat/completions. Example with curl:
 
 ```sh
 curl http://localhost:8000/v1/chat/completions \
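# The README hunk above is truncated by the diff view. As a hedged sketch (not
# part of this commit), a complete request to the endpoint might look like the
# following; the model name, message, and temperature are illustrative values
# matching the shard_mappings keys added in exo/api/chatgpt_api.py below.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama-3.1-8b",
    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
    "temperature": 0.7
  }'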

+ 0 - 1
examples/llama3_distributed.py

@@ -50,7 +50,6 @@ async def run_prompt(prompt: str):
         )
 
     await peer2.connect()
-    await peer2.global_reset(shard, set(), 2)
 
     try:
         await peer2.send_prompt(shard, prompt, request_id)

+ 44 - 11
exo/api/chatgpt_api.py

@@ -13,10 +13,7 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node
 
 shard_mappings = {
-    "llama-3-8b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
-    },
+    ### llama
     "llama-3.1-8b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
     },
@@ -24,12 +21,23 @@ shard_mappings = {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
     },
     "llama-3.1-405b": {
-        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+        "MLXDynamicShardInferenceEngine": Shard(model_id="/Users/alex/405b-instruct-4bit", start_layer=0, end_layer=0, n_layers=126),
+    },
+    "llama-3-8b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
     },
     "llama-3-70b": {
         "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
         "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
     },
+    ### mistral
+    "mistral-nemo": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Nemo-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=40),
+    },
+    "mistral-large": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Mistral-Large-Instruct-2407-4bit", start_layer=0, end_layer=0, n_layers=88),
+    },
 }
 
 class Message:
@@ -124,6 +132,17 @@ def build_prompt(tokenizer, messages: List[Message]):
         messages, tokenize=False, add_generation_prompt=True
     )
 
+def parse_message(data: dict):
+    if 'role' not in data or 'content' not in data:
+        raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
+    return Message(data['role'], data['content'])
+
+def parse_chat_request(data: dict):
+    return ChatCompletionRequest(
+        data.get('model', 'llama-3.1-8b'),
+        [parse_message(msg) for msg in data['messages']],
+        data.get('temperature', 0.0)
+    )
 
 class ChatGPTAPI:
     def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
@@ -150,32 +169,46 @@ class ChatGPTAPI:
         self.app.router.add_get('/', self.handle_root)
         self.app.router.add_static('/', self.static_dir, name='static')
 
+        # Add middleware to log every request
+        self.app.middlewares.append(self.log_request)
+
+    async def log_request(self, app, handler):
+        async def middleware(request):
+            if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
+            return await handler(request)
+        return middleware
+
     async def handle_root(self, request):
+        print(f"Handling root request from {request.remote}")
         return web.FileResponse(self.static_dir / 'index.html')
 
     async def handle_post_chat_token_encode(self, request):
         data = await request.json()
         shard = shard_mappings.get(data.get('model', 'llama-3.1-8b'), {}).get(self.inference_engine_classname)
-        messages = data.get('messages', [])
+        messages = [parse_message(msg) for msg in data.get('messages', [])]
         tokenizer = await resolve_tokenizer(shard.model_id)
         return web.json_response({'length': len(build_prompt(tokenizer, messages))})
 
     async def handle_post_chat_completions(self, request):
         data = await request.json()
+        if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
         stream = data.get('stream', False)
-        messages = [Message(**msg) for msg in data['messages']]
-        chat_request = ChatCompletionRequest(data.get('model', 'llama-3.1-8b'), messages, data.get('temperature', 0.0))
+        chat_request = parse_chat_request(data)
         if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
             chat_request.model = "llama-3.1-8b"
-        shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
+        if not chat_request.model or chat_request.model not in shard_mappings:
+            if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}. Defaulting to llama-3.1-8b")
+            chat_request.model = "llama-3.1-8b"
+        shard = shard_mappings[chat_request.model].get(self.inference_engine_classname, None)
         if not shard:
-            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
+            supported_models = [model for model, engines in shard_mappings.items() if self.inference_engine_classname in engines]
+            return web.json_response({'detail': f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"}, status=400)
         request_id = str(uuid.uuid4())
 
         tokenizer = await resolve_tokenizer(shard.model_id)
         if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
 
-        prompt = build_prompt(tokenizer, messages)
+        prompt = build_prompt(tokenizer, chat_request.messages)
         callback_id = f"chatgpt-api-wait-response-{request_id}"
         callback = self.node.on_token.register(callback_id)
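For reference, a minimal usage sketch (not part of this commit) of the `parse_chat_request` helper added above; the payload and import path are assumptions for illustration. Note that unrecognized models now fall back to `llama-3.1-8b` instead of returning a 400.

```python
# Hypothetical usage of the new request-parsing helpers (illustrative only).
from exo.api.chatgpt_api import parse_chat_request

payload = {
    "model": "llama-3.1-8b",
    "messages": [{"role": "user", "content": "Say hello."}],
    "temperature": 0.0,
}

chat_request = parse_chat_request(payload)
assert chat_request.model == "llama-3.1-8b"  # unknown models also default to this
assert len(chat_request.messages) == 1       # raw dicts are parsed into Message objects
```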
 

+ 6 - 11
exo/inference/debug_inference_engine.py

@@ -12,18 +12,13 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
     _tokenizer = Tokenizer(str(Path(model_id) / "tokenizer.model"))
 
     prompt = "In a single word only, what is the last name of the president of the United States? "
-    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
+    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
 
-    await inference_engine_1.reset_shard(shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32))
-    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-
-    await inference_engine_2.reset_shard(shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32))
-    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
-
-    # don't reset the second time
-    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
-    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
+    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
+    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
+    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
 
     print(f"{resp2=}")
     print(f"full: {_tokenizer.decode(resp_full)}")

+ 2 - 6
exo/inference/inference_engine.py

@@ -6,13 +6,9 @@ from .shard import Shard
 
 class InferenceEngine(ABC):
     @abstractmethod
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
+    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
         pass
 
     @abstractmethod
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
-        pass
-
-    @abstractmethod
-    async def reset_shard(self, shard: Shard):
+    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
         pass
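For orientation, a hedged sketch (not in this commit) of what a concrete engine must implement after this interface change: `request_id` now comes first so engines can keep per-request state. The class name and echo behaviour are purely illustrative.

```python
import numpy as np
from typing import Optional, Tuple

from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard

class EchoInferenceEngine(InferenceEngine):
    """Illustrative stub satisfying the updated abstract interface."""

    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str,
                           inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
        # A real engine would run the shard's layers here; request_id lets it keep
        # per-request state (e.g. a KV cache) rather than one shared cache.
        return np.array([0]), inference_state or "", False

    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray,
                           inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
        return input_data, inference_state or "", True
```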

+ 4 - 8
exo/inference/mlx/sharded_inference_engine.py

@@ -10,20 +10,16 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
 
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(self.tokenizer.encode(prompt))))
+        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(self.tokenizer.encode(prompt))))
         return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
-        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(mx.array(input_data)))
+        output_data: np.ndarray = np.array(self.stateful_sharded_model.step(request_id, mx.array(input_data)))
         return output_data, "", output_data.size == 1 and output_data.item() == self.tokenizer.eos_token_id
 
-    async def reset_shard(self, shard: Shard):
-        await self.ensure_shard(shard)
-        self.stateful_sharded_model.reset()
-
     async def ensure_shard(self, shard: Shard):
         if self.shard == shard:
             return

+ 7 - 4
exo/inference/mlx/sharded_model.py

@@ -11,10 +11,11 @@ class StatefulShardedModel:
     def __init__(self, shard: Shard, model: nn.Module):
         self.shard = shard
         self.model = model
-        self.reset()
+        self.request_cache: Dict[str, Tuple[str, KVCache]] = {}
 
     def step(
         self,
+        request_id: str,
         x,
         temp: float = 0.0,
         top_p: float = 1.0,
@@ -38,7 +39,9 @@ class StatefulShardedModel:
 
         y = x
 
-        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.cache)
+        if request_id not in self.request_cache:
+            self.init_cache(request_id)
+        output = self.model(y[None] if self.shard.is_first_layer() else y, cache=self.request_cache[request_id])
 
         if self.shard.is_last_layer():
             logits = output[:, -1, :]
@@ -56,10 +59,10 @@ class StatefulShardedModel:
     ) -> Generator[Tuple[mx.array, mx.array], None, None]:
         return self.step(x, temp, top_p, logit_bias)
 
-    def reset(self):
+    def init_cache(self, request_id: str):
         kv_heads = (
             [self.model.n_kv_heads] * len(self.model.layers)
             if isinstance(self.model.n_kv_heads, int)
             else self.model.n_kv_heads
         )
-        self.cache = [KVCache(self.model.head_dim, n) for n in kv_heads]
+        self.request_cache[request_id] = [KVCache(self.model.head_dim, n) for n in kv_heads]
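The core of this change is keying the KV cache by `request_id` instead of holding a single shared cache. A simplified, framework-free sketch of the pattern (assumed names, not the mlx implementation):

```python
from typing import Dict, List

class PerRequestCache:
    """Minimal illustration of the request-keyed cache pattern used above."""

    def __init__(self, n_layers: int):
        self.n_layers = n_layers
        self.request_cache: Dict[str, List[dict]] = {}

    def get(self, request_id: str) -> List[dict]:
        # Lazily create one cache entry per layer the first time a request is
        # seen, mirroring init_cache(request_id) in StatefulShardedModel.
        if request_id not in self.request_cache:
            self.request_cache[request_id] = [{} for _ in range(self.n_layers)]
        return self.request_cache[request_id]

caches = PerRequestCache(n_layers=32)
assert caches.get("req-A") is not caches.get("req-B")  # requests no longer share state
assert caches.get("req-A") is caches.get("req-A")      # a request reuses its own cache
```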

+ 3 - 9
exo/inference/mlx/sharded_utils.py

@@ -25,8 +25,8 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 MODEL_REMAPPING = {
-    "mistral": "llama",  # mistral is compatible with llama
-    "phi-msft": "phixtral",
+    "sharded_mistral": "sharded_llama",  # mistral is compatible with llama
+    "sharded_phi-msft": "sharded_phixtral",
 }
 
 def _get_classes(config: dict):
@@ -122,16 +122,10 @@ def load_model_shard(
         weights = model.sanitize(weights)
 
     if (quantization := config.get("quantization", None)) is not None:
-        # Handle legacy models which may not have everything quantized
-        def class_predicate(p, m):
-            if not hasattr(m, "to_quantized"):
-                return False
-            return f"{p}.scales" in all_weights_keys
-
         nn.quantize(
             model,
             **quantization,
-            class_predicate=class_predicate,
+            class_predicate=None,
         )
 
     filtered_weights = {}

+ 6 - 11
exo/inference/test_inference_engine.py

@@ -8,18 +8,13 @@ import numpy as np
 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
     prompt = "In a single word only, what is the last name of the current president of the USA?"
-    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
-    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
+    resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
+    next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
 
-    await inference_engine_1.reset_shard(shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32))
-    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
-
-    await inference_engine_2.reset_shard(shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32))
-    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
-
-    # don't reset the second time
-    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
-    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor(shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
+    resp1, inference_state_1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+    resp2, inference_state_2, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, inference_state=inference_state_1)
+    resp3, inference_state_3, _ = await inference_engine_1.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=resp2, inference_state=inference_state_2)
+    resp4, inference_state_4, _ = await inference_engine_2.infer_tensor("B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, inference_state=inference_state_3)
 
     assert np.array_equal(resp_full, resp2)
     assert np.array_equal(next_resp_full, resp4)

+ 3 - 7
exo/inference/tinygrad/inference.py

@@ -143,7 +143,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     def __init__(self):
         self.shard = None
 
-    async def infer_prompt(self, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+        # TODO: we need to refactor models/llamaa to handle per-request-kv-cache. right now it's shared between requests.
         await self.ensure_shard(shard)
         start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
@@ -157,7 +158,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
         return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
 
-    async def infer_tensor(self, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
+    async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
         await self.ensure_shard(shard)
         start_pos = json.loads(inference_state).get("start_pos", 0) if inference_state else 0
 
@@ -167,11 +168,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
 
         return output_data, json.dumps({"start_pos": start_pos}), output_data.size == 1 and output_data.item() in self.tokenizer.stop_tokens
 
-    async def reset_shard(self, shard: Shard):
-        await self.ensure_shard(shard)
-
-        self.model.reset()
-
     async def ensure_shard(self, shard: Shard):
         if self.shard == shard:
             return

+ 0 - 8
exo/networking/grpc/grpc_peer_handle.py

@@ -74,10 +74,6 @@ class GRPCPeerHandle(PeerHandle):
             return None, response.is_finished
         return np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape), response.is_finished
 
-    async def reset_shard(self, shard: Shard) -> None:
-        request = node_service_pb2.ResetShardRequest(shard=node_service_pb2.Shard(model_id=shard.model_id, start_layer=shard.start_layer, end_layer=shard.end_layer, n_layers=shard.n_layers))
-        await self.stub.ResetShard(request)
-
     async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
         request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
         response = await self.stub.CollectTopology(request)
@@ -90,10 +86,6 @@ class GRPCPeerHandle(PeerHandle):
                 topology.add_edge(node_id, peer_id)
         return topology
 
-    async def global_reset(self, base_shard: Shard, visited: set[str], max_depth: int) -> None:
-        request = node_service_pb2.GlobalResetRequest(base_shard=node_service_pb2.Shard(model_id=base_shard.model_id, start_layer=base_shard.start_layer, end_layer=base_shard.end_layer, n_layers=base_shard.n_layers), visited=visited, max_depth=max_depth)
-        await self.stub.GlobalReset(request)
-
     async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
         request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
         await self.stub.SendResult(request)

+ 0 - 14
exo/networking/grpc/grpc_server.py

@@ -60,12 +60,6 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         tensor_data = result[0].tobytes() if result[0] is not None else None
         return node_service_pb2.InferenceResult(tensor=node_service_pb2.Tensor(tensor_data=tensor_data, shape=result[0].shape, dtype=str(result[0].dtype)), is_finished=result[1]) if result[0] is not None else node_service_pb2.InferenceResult(is_finished=result[1])
 
-    async def ResetShard(self, request, context):
-        shard = Shard(model_id=request.shard.model_id, start_layer=request.shard.start_layer, end_layer=request.shard.end_layer, n_layers=request.shard.n_layers)
-        if DEBUG >= 2: print(f"Received ResetShard request: {shard}")
-        await self.node.reset_shard(shard)
-        return node_service_pb2.Empty()
-
     async def CollectTopology(self, request, context):
         max_depth = request.max_depth
         visited = set(request.visited)
@@ -75,14 +69,6 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
         if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
         return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
 
-    async def GlobalReset(self, request, context):
-        base_shard = Shard(model_id=request.base_shard.model_id, start_layer=request.base_shard.start_layer, end_layer=request.base_shard.end_layer, n_layers=request.base_shard.n_layers)
-        visited = set(request.visited)
-        max_depth = request.max_depth
-        if DEBUG >= 2: print(f"Received GlobalReset request: {base_shard=} {visited=} {max_depth=}")
-        await self.node.global_reset(base_shard, visited, max_depth)
-        return node_service_pb2.Empty()
-
     async def SendResult(self, request, context):
         request_id = request.request_id
         result = request.result

+ 0 - 12
exo/networking/grpc/node_service.proto

@@ -5,10 +5,8 @@ package node_service;
 service NodeService {
   rpc SendPrompt (PromptRequest) returns (Tensor) {}
   rpc SendTensor (TensorRequest) returns (Tensor) {}
-  rpc ResetShard (ResetShardRequest) returns (Empty) {}
   rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
-  rpc GlobalReset (GlobalResetRequest) returns (Empty) {}
   rpc SendResult (SendResultRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
 }
@@ -49,21 +47,11 @@ message Tensor {
   string dtype = 3;
 }
 
-message ResetShardRequest {
-  Shard shard = 1;
-}
-
 message CollectTopologyRequest {
   repeated string visited = 1;
   int32 max_depth = 2;
 }
 
-message GlobalResetRequest {
-  Shard base_shard = 1;
-  repeated string visited = 2;
-  int32 max_depth = 3;
-}
-
 message Topology {
   map<string, DeviceCapabilities> nodes = 1;
   map<string, Peers> peer_graph = 2;

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 0 - 86
exo/networking/grpc/node_service_pb2_grpc.py

@@ -49,11 +49,6 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Tensor.FromString,
                 _registered_method=True)
-        self.ResetShard = channel.unary_unary(
-                '/node_service.NodeService/ResetShard',
-                request_serializer=node__service__pb2.ResetShardRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
         self.GetInferenceResult = channel.unary_unary(
                 '/node_service.NodeService/GetInferenceResult',
                 request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
@@ -64,11 +59,6 @@ class NodeServiceStub(object):
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Topology.FromString,
                 _registered_method=True)
-        self.GlobalReset = channel.unary_unary(
-                '/node_service.NodeService/GlobalReset',
-                request_serializer=node__service__pb2.GlobalResetRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Empty.FromString,
-                _registered_method=True)
         self.SendResult = channel.unary_unary(
                 '/node_service.NodeService/SendResult',
                 request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
@@ -96,12 +86,6 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def ResetShard(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
     def GetInferenceResult(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -114,12 +98,6 @@ class NodeServiceServicer(object):
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
 
-    def GlobalReset(self, request, context):
-        """Missing associated documentation comment in .proto file."""
-        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-        context.set_details('Method not implemented!')
-        raise NotImplementedError('Method not implemented!')
-
     def SendResult(self, request, context):
         """Missing associated documentation comment in .proto file."""
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -145,11 +123,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
                     response_serializer=node__service__pb2.Tensor.SerializeToString,
             ),
-            'ResetShard': grpc.unary_unary_rpc_method_handler(
-                    servicer.ResetShard,
-                    request_deserializer=node__service__pb2.ResetShardRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
             'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
                     servicer.GetInferenceResult,
                     request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
@@ -160,11 +133,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
                     response_serializer=node__service__pb2.Topology.SerializeToString,
             ),
-            'GlobalReset': grpc.unary_unary_rpc_method_handler(
-                    servicer.GlobalReset,
-                    request_deserializer=node__service__pb2.GlobalResetRequest.FromString,
-                    response_serializer=node__service__pb2.Empty.SerializeToString,
-            ),
             'SendResult': grpc.unary_unary_rpc_method_handler(
                     servicer.SendResult,
                     request_deserializer=node__service__pb2.SendResultRequest.FromString,
@@ -240,33 +208,6 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
-    @staticmethod
-    def ResetShard(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/ResetShard',
-            node__service__pb2.ResetShardRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
     @staticmethod
     def GetInferenceResult(request,
             target,
@@ -321,33 +262,6 @@ class NodeService(object):
             metadata,
             _registered_method=True)
 
-    @staticmethod
-    def GlobalReset(request,
-            target,
-            options=(),
-            channel_credentials=None,
-            call_credentials=None,
-            insecure=False,
-            compression=None,
-            wait_for_ready=None,
-            timeout=None,
-            metadata=None):
-        return grpc.experimental.unary_unary(
-            request,
-            target,
-            '/node_service.NodeService/GlobalReset',
-            node__service__pb2.GlobalResetRequest.SerializeToString,
-            node__service__pb2.Empty.FromString,
-            options,
-            channel_credentials,
-            insecure,
-            call_credentials,
-            compression,
-            wait_for_ready,
-            timeout,
-            metadata,
-            _registered_method=True)
-
     @staticmethod
     def SendResult(request,
             target,

+ 0 - 8
exo/networking/peer_handle.py

@@ -38,18 +38,10 @@ class PeerHandle(ABC):
     async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         pass
 
-    @abstractmethod
-    async def reset_shard(self, shard: Shard) -> None:
-        pass
-
     @abstractmethod
     async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
         pass
 
-    @abstractmethod
-    async def global_reset(self, base_shard: Shard, visited: set[str], max_depth: int) -> None:
-        pass
-
     @abstractmethod
     async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
         pass

+ 0 - 8
exo/orchestration/node.py

@@ -22,10 +22,6 @@ class Node(ABC):
     async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
-    @abstractmethod
-    async def reset_shard(self, shard: Shard) -> None:
-        pass
-
     @abstractmethod
     async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
         pass
@@ -34,10 +30,6 @@ class Node(ABC):
     async def collect_topology(self, visited: set[str] = set(), max_depth: int = 2) -> Topology:
         pass
 
-    @abstractmethod
-    async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
-        pass
-
     @property
     @abstractmethod
     def current_topology(self) -> Topology:

+ 2 - 33
exo/orchestration/standard_node.py

@@ -79,7 +79,7 @@ class StandardNode(Node):
             await self.forward_to_next_shard(shard, prompt, request_id)
             return
 
-        result, inference_state, is_finished = await self.inference_engine.infer_prompt(shard, prompt, inference_state=inference_state)
+        result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
         is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if is_finished:
             self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -115,7 +115,7 @@ class StandardNode(Node):
 
         try:
             if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
-            result, inference_state, is_finished = await self.inference_engine.infer_tensor(shard, tensor, inference_state=inference_state)
+            result, inference_state, is_finished = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
             is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
             if is_finished:
                 self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -178,12 +178,6 @@ class StandardNode(Node):
             raise ValueError(f"No current partition found for node: {self.id}")
         return shards[current_partition_index]
 
-    async def reset_shard(self, base_shard: Shard) -> None:
-        # Implement shard reset logic
-        if DEBUG >= 2: print(f"Resetting shard: {base_shard}")
-        self.buffered_token_output = {}
-        await self.inference_engine.reset_shard(self.get_current_shard(base_shard))
-
     async def update_peers(self, wait_for_peers: int = 0) -> None:
         self.peers = await self.discovery.discover_peers(wait_for_peers)
         if DEBUG >= 2: print(f"Starting with the following peers: {self.peers}")
@@ -245,31 +239,6 @@ class StandardNode(Node):
         if self.topology_viz: self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology))
         return next_topology
 
-    # TODO: unify this and collect_topology as global actions
-    async def global_reset(self, base_shard: Shard, visited: set[str] = set(), max_depth: int = 2) -> None:
-        shard = self.get_current_shard(base_shard)
-        await self.reset_shard(shard)
-
-        if DEBUG >= 2: print(f"Global reset {base_shard=} {max_depth=} {visited=}")
-
-        prev_visited = visited.copy()
-        visited.update(p.id() for p in self.peers)
-
-        for peer in self.peers:
-            if peer.id() in prev_visited:
-                if DEBUG >= 2: print(f"Already visited {peer.id()}. Skipping...")
-                continue
-
-            if max_depth <= 0:
-                if DEBUG >= 2: print(f"Max depth reached. Skipping...")
-                continue
-
-            try:
-                print(f"Forwarding global reset to peer {peer.id()}")
-                await peer.global_reset(base_shard, visited, max_depth = max_depth - 1)
-            except Exception as e:
-                print(f"Error collecting topology from {peer.id()}: {e}")
-
     @property
     def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
         return self._on_token

+ 23 - 0
tinychat/examples/tinychat/index.html

@@ -30,6 +30,27 @@
 
   <link rel="stylesheet" href="index.css">
   <link rel="stylesheet" href="common.css">
+
+  <style>
+    .model-selector {
+      display: flex;
+      justify-content: center;
+      padding: 20px 0;
+    }
+    .model-selector select {
+      padding: 10px 20px;
+      font-size: 16px;
+      border: 1px solid #ccc;
+      border-radius: 5px;
+      background-color: #f8f8f8;
+      cursor: pointer;
+    }
+    .model-selector select:focus {
+      outline: none;
+      border-color: #007bff;
+      box-shadow: 0 0 0 2px rgba(0,123,255,.25);
+    }
+  </style>
 </head>
 
 <body>
@@ -41,6 +62,8 @@
         <option value="llama-3.1-405b">Llama 3.1 405B</option>
         <option value="llama-3-8b">Llama 3 8B</option>
         <option value="llama-3-70b">Llama 3 70B</option>
+        <option value="mistral-nemo">Mistral Nemo</option>
+        <option value="mistral-large">Mistral Large</option>
       </select>
     </div>
     <div class="home centered" x-show="home === 0" x-transition x-effect="

Some files were not shown because too many files changed in this diff