
optimise networking, remove bloat

Alex Cheema, 4 months ago
Parent
Commit
c9ded9ba96

+ 63 - 50
exo/api/chatgpt_api.py

@@ -21,6 +21,7 @@ from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
+from collections import defaultdict
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -160,6 +161,11 @@ class ChatGPTAPI:
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
     self.default_model = default_model or "llama-3.2-1b"
+    self.token_queues = defaultdict(asyncio.Queue)
+
+    # Get the callback system and register our handler
+    self.token_callback = node.on_token.register("chatgpt-api-token-handler")
+    self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -346,9 +352,6 @@ class ChatGPTAPI:
     #   request_id = str(uuid.uuid4())
     #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
 
-    callback_id = f"chatgpt-api-wait-response-{request_id}"
-    callback = self.node.on_token.register(callback_id)
-
     if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
 
     try:
@@ -367,53 +370,63 @@ class ChatGPTAPI:
         )
         await response.prepare(request)
 
-        async def stream_result(_request_id: str, token: int, is_finished: bool):
-          finish_reason = None
-          eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
-                                                                                                                             AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
-          if token == eos_token_id:
+        try:
+          # Stream tokens while waiting for inference to complete
+          while True:
+            token, is_finished = await asyncio.wait_for(
+              self.token_queues[request_id].get(),
+              timeout=self.response_timeout
+            )
+
+            finish_reason = None
+            eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
+
+            if token == eos_token_id:
+              if is_finished:
+                finish_reason = "stop"
+            if is_finished and not finish_reason:
+              finish_reason = "length"
+
+            completion = generate_completion(
+              chat_request,
+              tokenizer,
+              prompt,
+              request_id,
+              [token],
+              stream,
+              finish_reason,
+              "chat.completion",
+            )
+
+            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+
             if is_finished:
-              finish_reason = "stop"
-          if is_finished and not finish_reason:
-            finish_reason = "length"
-
-          completion = generate_completion(
-            chat_request,
-            tokenizer,
-            prompt,
-            request_id,
-            [token],
-            stream,
-            finish_reason,
-            "chat.completion",
+              break
+
+          await response.write_eof()
+          return response
+
+        except asyncio.TimeoutError:
+          return web.json_response({"detail": "Response generation timed out"}, status=408)
+
+        except Exception as e:
+          if DEBUG >= 2: traceback.print_exc()
+          return web.json_response(
+            {"detail": f"Error processing prompt: {str(e)}"},
+            status=500
           )
-          if DEBUG >= 2: print(f"Streaming completion: {completion}")
-          try:
-            await response.write(f"data: {json.dumps(completion)}\n\n".encode())
-          except Exception as e:
-            if DEBUG >= 2: print(f"Error streaming completion: {e}")
-            if DEBUG >= 2: traceback.print_exc()
-
-        def on_result(_request_id: str, token: int, is_finished: bool):
-          if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, token, is_finished))
-
-          return _request_id == request_id and is_finished
-
-        _, token, _ = await callback.wait(on_result, timeout=self.response_timeout)
-        if request_id in self.stream_tasks:  # in case there is still a stream task running, wait for it to complete
-          if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
-          try:
-            await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
-          except asyncio.TimeoutError:
-            print("WARNING: Stream task timed out. This should not happen.")
-        await response.write_eof()
-        return response
-      else:
-        _, token, _ = await callback.wait(
-          lambda _request_id, token, is_finished: _request_id == request_id and is_finished,
-          timeout=self.response_timeout,
-        )
 
+        finally:
+          # Clean up the queue for this request
+          if request_id in self.token_queues:
+            del self.token_queues[request_id]
+      else:
+        tokens = []
+        while True:
+          token, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
+          tokens.append(token)
+          if is_finished:
+            break
         finish_reason = "length"
         eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
         if DEBUG >= 2: print(f"Checking if end of tokens result {token=} is {eos_token_id=}")
@@ -426,9 +439,6 @@ class ChatGPTAPI:
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-    finally:
-      deregistered_callback = self.node.on_token.deregister(callback_id)
-      if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
 
   async def handle_delete_model(self, request):
     try:
@@ -566,6 +576,9 @@ class ChatGPTAPI:
         status=500
       )
 
+  async def handle_token(self, request_id: str, token: int, is_finished: bool):
+    await self.token_queues[request_id].put((token, is_finished))
+
   async def run(self, host: str = "0.0.0.0", port: int = 52415):
     runner = web.AppRunner(self.app)
     await runner.setup()
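
The change above replaces per-request callback registration with a single token handler that fans tokens out into per-request asyncio queues, which the HTTP handler drains. A minimal, self-contained sketch of that pattern follows; the producer/consumer names (consume_tokens, the demo loop) are illustrative and not part of exo:

# Illustrative sketch of the queue-per-request streaming pattern above.
import asyncio
from collections import defaultdict

token_queues: defaultdict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

async def handle_token(request_id: str, token: int, is_finished: bool):
    # The single registered callback only enqueues; no per-request callbacks.
    await token_queues[request_id].put((token, is_finished))

async def consume_tokens(request_id: str, timeout: float = 5.0):
    try:
        while True:
            token, is_finished = await asyncio.wait_for(
                token_queues[request_id].get(), timeout=timeout)
            print("token:", token)
            if is_finished:
                break
    finally:
        # Mirrors the finally-block cleanup in handle_post_chat_completions.
        token_queues.pop(request_id, None)

async def main():
    rid = "demo"
    for t in (101, 102, 103):
        await handle_token(rid, t, is_finished=(t == 103))
    await consume_tokens(rid)

asyncio.run(main())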

+ 7 - 0
exo/inference/mlx/perf_improvements.md

@@ -0,0 +1,7 @@
+# Perf improvements
+
+Target: 460 tok/sec
+- removing sample goes from 369 -> 402
+- performance degrades as we generate more tokens
+- make mlx inference engine synchronous, removing thread pool executor: 402 -> 413
+- remove self.on_opaque_status.trigger_all: 413 -> 418
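
For context, the tok/sec figures in this note can be reproduced with a simple wall-clock measurement using perf_counter_ns, the same clock node.py uses below. This is a hypothetical harness; generate_one_token is a placeholder, not an exo or mlx_lm API:

# Hypothetical tok/sec measurement; generate_one_token is a stand-in
# for one decode step of the real pipeline.
import time

def measure_tps(generate_one_token, n_tokens: int = 200) -> float:
    start = time.perf_counter_ns()
    for _ in range(n_tokens):
        generate_one_token()
    elapsed_s = (time.perf_counter_ns() - start) / 1e9
    return n_tokens / elapsed_s

if __name__ == "__main__":
    print(f"{measure_tps(lambda: None):.1f} tok/sec (no-op baseline)")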

+ 21 - 28
exo/inference/mlx/sharded_inference_engine.py

@@ -1,7 +1,7 @@
 import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.sample_utils import top_p_sampling
+from mlx_lm.sample_utils import top_p_sampling, make_sampler
 import mlx.optimizers as optim
 from ..inference_engine import InferenceEngine
 from .sharded_utils import load_shard, get_image_from_str
@@ -10,8 +10,6 @@ from ..shard import Shard
 from typing import Dict, Optional, Tuple
 from exo.download.shard_download import ShardDownloader
 import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
 from collections import OrderedDict
 from mlx_lm.models.cache import make_prompt_cache
 
@@ -40,61 +38,60 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.executor = ThreadPoolExecutor(max_workers=1)
     self.caches = OrderedDict()
+    self.sampler_params: tuple[float, float, float, int] = (0.0, 0.0, 0.0, 1)
+    self.sampler = make_sampler(*self.sampler_params)
 
   async def poll_state(self, request_id: str, max_caches=2):
     if request_id in self.caches:
       self.caches.move_to_end(request_id)
     else:
-      newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
+      newcache = make_prompt_cache(self.model)
       if len(self.caches) > max_caches:
         self.caches.popitem(last=False)
       self.caches[request_id] = newcache
     return {"cache": self.caches[request_id]}
 
-  async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
-    y = mx.array(x)
-    logits = y[:, -1, :]
-    out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
-    return out
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
+    if (temp, top_p, 0.0, 1) != self.sampler_params:
+      self.sampler_params = (temp, top_p, 0.0, 1)
+      self.sampler = make_sampler(*self.sampler_params)
+    logits = mx.array(x)
+    logits = logits[:, -1, :]
+    logprobs = logits - mx.logsumexp(logits, keepdims=True)
+    return np.asarray(self.sampler(logprobs), dtype=int)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    tokens = self.tokenizer.encode(prompt)
+    return np.asarray(tokens)
 
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return self.tokenizer.decode(tokens)
 
   async def save_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
+    self.model.save_weights(path)
 
   async def load_checkpoint(self, shard: Shard, path: str):
     await self.ensure_shard(shard)
-    await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
+    self.model.load_weights(path)
     
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    loop = asyncio.get_running_loop()
     state = await self.poll_state(request_id)
     x = mx.array(input_data)
-    output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
+    output_data = np.array(self.model(x, **state), copy=False)
     return output_data
 
   async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
     await self.ensure_shard(shard)
     await self.save_session('loss', loss_fns[loss])
-    loop = asyncio.get_running_loop()
-    #print(f"evaluate in <- {inputs}")
     x = mx.array(inputs)
     y = mx.array(targets)
     l = mx.array(lengths)
-    score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
-    #print(f"evaluate out -> {score}")
+    score = self.session['loss'](self.model, x, y, l)
     return score
 
   async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
@@ -130,7 +127,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
     #print(layers[0])
 
-    return score, np.array(layers[0]['input_layernorm'])
+    return score, np.array(layers[0]['input_layernorm'], copy=False)
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
@@ -139,11 +136,7 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
     model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
 
     if self.shard != shard:
-
-      def load_shard_wrapper():
-        return asyncio.run(load_shard(model_path, shard))
-
-      model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
+      model_shard, self.tokenizer = await load_shard(model_path, shard)
       self.shard = shard
       self.model = model_shard 
       self.caches = OrderedDict()
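
The sample() change above builds the mlx_lm sampler once and only rebuilds it when temp or top_p change. A rough sketch of that memoisation pattern; fake_make_sampler stands in for mlx_lm.sample_utils.make_sampler so the snippet runs without mlx installed:

# Sketch of "rebuild the sampler only when its parameters change".
def fake_make_sampler(temp, top_p, min_p, min_tokens_to_keep):
    def sampler(logprobs):
        # greedy stand-in for the real sampling behaviour
        return max(range(len(logprobs)), key=logprobs.__getitem__)
    return sampler

class SamplerCache:
    def __init__(self):
        # (temp, top_p, min_p, min_tokens_to_keep)
        self.params = (0.0, 0.0, 0.0, 1)
        self.sampler = fake_make_sampler(*self.params)

    def sample(self, logprobs, temp=0.0, top_p=0.0):
        if (temp, top_p, 0.0, 1) != self.params:  # rebuild only on change
            self.params = (temp, top_p, 0.0, 1)
            self.sampler = fake_make_sampler(*self.params)
        return self.sampler(logprobs)

cache = SamplerCache()
print(cache.sample([-1.2, -0.3, -2.5]))            # reuses the cached sampler
print(cache.sample([-1.2, -0.3, -2.5], temp=0.7))  # rebuilds once, then caches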

+ 8 - 4
exo/main.py

@@ -150,9 +150,9 @@ api = ChatGPTAPI(
   on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
   default_model=args.default_model
 )
-node.on_token.register("update_topology_viz").on_next(
-  lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
-)
+# node.on_token.register("update_topology_viz").on_next(
+#   lambda req_id, token, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode([token])) if topology_viz and hasattr(inference_engine, "tokenizer") else None
+# )
 
 def preemptively_start_download(request_id: str, opaque_status: str):
   try:
@@ -200,7 +200,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
     print(f"Processing prompt: {prompt}")
     await node.process_prompt(shard, prompt, request_id=request_id)
 
-    _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
+    tokens = []
+    def on_token(_request_id, _token, _is_finished):
+      tokens.append(_token)
+      return _request_id == request_id and _is_finished
+    await callback.wait(on_token, timeout=300)
 
     print("\nGenerated response:")
     print(tokenizer.decode(tokens))
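
run_model_cli now accumulates tokens inside the wait predicate rather than taking them from the callback's return value. A simplified stand-in for that wait-on-predicate behaviour; WaitableCallback below is illustrative and not exo's callback implementation:

# Illustrative wait-on-predicate callback; tokens are collected as a side
# effect of the predicate, as in the on_token closure above.
import asyncio

class WaitableCallback:
    def __init__(self):
        self._done = asyncio.Event()
        self._predicate = lambda *args: False

    def trigger(self, *args):
        if self._predicate(*args):
            self._done.set()

    async def wait(self, predicate, timeout=None):
        self._predicate = predicate
        await asyncio.wait_for(self._done.wait(), timeout)

async def main():
    cb = WaitableCallback()
    tokens = []
    def on_token(request_id, token, is_finished):
        tokens.append(token)                       # accumulate as events arrive
        return request_id == "req-1" and is_finished
    waiter = asyncio.create_task(cb.wait(on_token, timeout=5))
    await asyncio.sleep(0)                         # let wait() install the predicate
    for tok, last in ((1, False), (2, False), (3, True)):
        cb.trigger("req-1", tok, last)
    await waiter
    print(tokens)                                  # [1, 2, 3]

asyncio.run(main())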

+ 4 - 14
exo/networking/grpc/grpc_peer_handle.py

@@ -71,7 +71,7 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
     request = node_service_pb2.PromptRequest(
       prompt=prompt,
       shard=node_service_pb2.Shard(
@@ -82,14 +82,9 @@ class GRPCPeerHandle(PeerHandle):
       ),
       request_id=request_id,
     )
-    response = await self.stub.SendPrompt(request)
+    await self.stub.SendPrompt(request)
 
-    if not response.tensor_data or not response.shape or not response.dtype:
-      return None
-
-    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
-
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
+  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
@@ -100,12 +95,7 @@ class GRPCPeerHandle(PeerHandle):
       tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
       request_id=request_id,
     )
-    response = await self.stub.SendTensor(request)
-
-    if not response.tensor_data or not response.shape or not response.dtype:
-      return None
-
-    return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
+    await self.stub.SendTensor(request)
   
   async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.ExampleRequest(

+ 6 - 9
exo/networking/grpc/grpc_server.py

@@ -50,10 +50,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
-    result = await self.node.process_prompt(shard, prompt, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
-    tensor_data = result.tobytes() if result is not None else None
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+    await self.node.process_prompt(shard, prompt, request_id)
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}")
+    return node_service_pb2.Empty()
 
   async def SendTensor(self, request, context):
     shard = Shard(
@@ -64,11 +63,9 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
-
-    result = await self.node.process_tensor(shard, tensor, request_id)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
-    tensor_data = result.tobytes() if result is not None else None
-    return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
+    await self.node.process_tensor(shard, tensor, request_id)
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}")
+    return node_service_pb2.Empty()
   
   async def SendExample(self, request, context):
     shard = Shard(

+ 2 - 2
exo/networking/grpc/node_service.proto

@@ -3,8 +3,8 @@ syntax = "proto3";
 package node_service;
 
 service NodeService {
-  rpc SendPrompt (PromptRequest) returns (Tensor) {}
-  rpc SendTensor (TensorRequest) returns (Tensor) {}
+  rpc SendPrompt (PromptRequest) returns (Empty) {}
+  rpc SendTensor (TensorRequest) returns (Empty) {}
   rpc SendExample (ExampleRequest) returns (Loss) {}
   rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
   rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
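
Because SendPrompt and SendTensor now return Empty, the generated stubs below must be regenerated from this proto. A sketch of doing that from Python, assuming grpcio-tools is installed; the include path and output directory here are illustrative, not the repo's actual build setup:

# Regenerate node_service_pb2*.py after editing node_service.proto.
from grpc_tools import protoc

protoc.main([
    "grpc_tools.protoc",
    "-I.",
    "--python_out=.",
    "--grpc_python_out=.",
    "node_service.proto",
])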

File diff suppressed because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 6 - 6
exo/networking/grpc/node_service_pb2_grpc.py

@@ -37,12 +37,12 @@ class NodeServiceStub(object):
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
                 request_serializer=node__service__pb2.PromptRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
                 request_serializer=node__service__pb2.TensorRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Tensor.FromString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
@@ -122,12 +122,12 @@ def add_NodeServiceServicer_to_server(servicer, server):
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
                     request_deserializer=node__service__pb2.PromptRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
                     request_deserializer=node__service__pb2.TensorRequest.FromString,
-                    response_serializer=node__service__pb2.Tensor.SerializeToString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
@@ -181,7 +181,7 @@ class NodeService(object):
             target,
             '/node_service.NodeService/SendPrompt',
             node__service__pb2.PromptRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -208,7 +208,7 @@ class NodeService(object):
             target,
             '/node_service.NodeService/SendTensor',
             node__service__pb2.TensorRequest.SerializeToString,
-            node__service__pb2.Tensor.FromString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,

+ 1 - 1
exo/networking/manual/manual_discovery.py

@@ -66,6 +66,6 @@ class ManualDiscovery(Discovery):
               pass
         except Exception as e:
           if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
-      await asyncio.sleep(1.0)
+      await asyncio.sleep(5.0)
 
       if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

+ 23 - 49
exo/orchestration/node.py

@@ -107,6 +107,8 @@ class Node:
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
+  token_count = 0
+  first_token_time = 0
   async def process_inference_result(
     self,
     shard,
@@ -116,7 +118,14 @@ class Node:
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
     is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    
     if shard.is_last_layer() and not is_finished:
+      self.token_count += 1
+      if self.token_count == 1:
+        self.first_token_time = time.perf_counter_ns()
+      if self.token_count % 20 == 0:
+        print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
+
       token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
@@ -142,60 +151,29 @@ class Node:
     base_shard: Shard,
     prompt: str,
     request_id: Optional[str] = None,
-  ) -> Optional[np.ndarray]:
+  ) -> None:
     shard = self.get_current_shard(base_shard)
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "start_process_prompt",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "prompt": prompt,
-          "request_id": request_id,
-        }),
-      )
-    )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id)
+    await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
-    asyncio.create_task(
-      self.broadcast_opaque_status(
-        request_id,
-        json.dumps({
-          "type": "node_status",
-          "node_id": self.id,
-          "status": "end_process_prompt",
-          "base_shard": base_shard.to_dict(),
-          "shard": shard.to_dict(),
-          "prompt": prompt,
-          "request_id": request_id,
-          "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
-        }),
-      )
-    )
-    return resp
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
 
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
-
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
       self.outstanding_requests[request_id] = "waiting"
-      resp = await self.forward_prompt(shard, prompt, request_id, 0)
+      await self.forward_prompt(shard, prompt, request_id, 0)
       return None
-    else:
-      self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_inference_result(shard, result, request_id)
-      return result
+
+    self.outstanding_requests[request_id] = "processing"
+    result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+    await self.process_inference_result(shard, result, request_id)
 
   async def enqueue_example(
     self,
@@ -339,7 +317,7 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-  ) -> Optional[np.ndarray]:
+  ) -> None:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -357,7 +335,7 @@ class Node:
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id)
+    await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -371,18 +349,16 @@ class Node:
           "shard": shard.to_dict(),
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
-          "result_size": resp.size if resp is not None else 0,
         }),
       )
     )
-    return resp
 
   async def _process_tensor(
     self,
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-  ) -> Optional[np.ndarray]:
+  ) -> None:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
@@ -391,13 +367,11 @@ class Node:
     try:
       self.outstanding_requests[request_id] = "processing"
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_inference_result(shard, result, request_id) 
-      return ret
+      await self.process_inference_result(shard, result, request_id) 
     except Exception as e:
       self.outstanding_requests.pop(request_id)
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
-      return None
   
   async def forward_example(
     self,
@@ -621,7 +595,7 @@ class Node:
 
     await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
     # in the case of opaque status, we also want to receive our own opaque statuses
-    self.on_opaque_status.trigger_all(request_id, status)
+    # self.on_opaque_status.trigger_all(request_id, status)
 
   @property
   def current_topology(self) -> Topology:
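
The tokens-per-second print added in process_inference_result amounts to the following standalone calculation; ThroughputMeter is an illustrative name, not part of the Node class:

# Standalone version of the TPS logging added above: count tokens from the
# first one and report count / elapsed seconds every N tokens.
import time

class ThroughputMeter:
    def __init__(self, log_every: int = 20):
        self.token_count = 0
        self.first_token_time = 0
        self.log_every = log_every

    def on_token(self, request_id: str):
        self.token_count += 1
        if self.token_count == 1:
            self.first_token_time = time.perf_counter_ns()
        elif self.token_count % self.log_every == 0:
            elapsed_s = (time.perf_counter_ns() - self.first_token_time) / 1e9
            print(f"[{request_id}] TPS: {self.token_count / elapsed_s}")

meter = ThroughputMeter(log_every=5)
for _ in range(10):
    time.sleep(0.001)   # stand-in for generating one token
    meter.on_token("demo")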

Some files were not shown because too many files changed in this diff.