Alex Cheema 1 year ago
Parent
Commit
db010d51fb

+ 13 - 19
exo/api/chatgpt_api.py

@@ -314,13 +314,13 @@ class ChatGPTAPI:
 
   async def handle_post_chat_completions(self, request):
     data = await request.json()
-    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
       chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
     shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
@@ -331,7 +331,7 @@ class ChatGPTAPI:
       )
 
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+    if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
 
     prompt = build_prompt(tokenizer, chat_request.messages)
     request_id = str(uuid.uuid4())
@@ -340,25 +340,13 @@ class ChatGPTAPI:
         self.on_chat_completion_request(request_id, chat_request, prompt)
       except Exception as e:
         if DEBUG >= 2: traceback.print_exc()
-    # request_id = None
-    # match = self.prompts.find_longest_prefix(prompt)
-    # if match and len(prompt) > len(match[1].prompt):
-    #     if DEBUG >= 2:
-    #       print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
-    #     request_id = match[1].request_id
-    #     self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-    #     # remove the matching prefix from the prompt
-    #     prompt = prompt[len(match[1].prompt):]
-    # else:
-    #   request_id = str(uuid.uuid4())
-    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
 
     try:
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
 
-      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
+      if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
 
       if stream:
         response = web.StreamResponse(
@@ -374,10 +362,12 @@ class ChatGPTAPI:
         try:
           # Stream tokens while waiting for inference to complete
           while True:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
             token, is_finished = await asyncio.wait_for(
               self.token_queues[request_id].get(),
               timeout=self.response_timeout
             )
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")
 
             finish_reason = None
             eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
@@ -408,10 +398,13 @@ class ChatGPTAPI:
           return response
 
         except asyncio.TimeoutError:
+          if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
           return web.json_response({"detail": "Response generation timed out"}, status=408)
 
         except Exception as e:
-          if DEBUG >= 2: traceback.print_exc()
+          if DEBUG >= 2: 
+            print(f"[ChatGPTAPI] Error processing prompt: {e}")
+            traceback.print_exc()
           return web.json_response(
             {"detail": f"Error processing prompt: {str(e)}"},
             status=500
@@ -420,6 +413,7 @@ class ChatGPTAPI:
         finally:
           # Clean up the queue for this request
           if request_id in self.token_queues:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
             del self.token_queues[request_id]
       else:
         tokens = []

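The handler above implements exo's OpenAI-compatible /v1/chat/completions endpoint. A minimal sketch of a streaming client for it, assuming a node is running locally with the API on its default port (52415) and that the model name is one of the model_cards keys; everything beyond what the diff shows is an assumption:

import json
import requests

resp = requests.post(
  "http://localhost:52415/v1/chat/completions",  # assumed default local endpoint
  json={
    "model": "llama-3.2-1b",  # assumed to be a valid model_cards key
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
  },
  stream=True,
)
for line in resp.iter_lines():
  # The streaming branch emits OpenAI-style SSE chunks ("data: {...}") per token.
  if line.startswith(b"data: ") and line != b"data: [DONE]":
    chunk = json.loads(line[len(b"data: "):])
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)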
+ 1 - 1
exo/download/hf/hf_helpers.py

@@ -437,7 +437,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
     shard_specific_patterns = set(["*.safetensors"])
-  if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
+  if DEBUG >= 4: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
 
 async def get_file_download_percentage(

+ 56 - 0
exo/main.py

@@ -38,6 +38,7 @@ import concurrent.futures
 import socket
 import resource
 import psutil
+import grpc
 
 # Configure uvloop for maximum performance
 def configure_uvloop():
@@ -308,6 +309,61 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
 async def main():
   loop = asyncio.get_running_loop()
 
+  # Set up OpenTelemetry
+  from opentelemetry import trace
+  from opentelemetry.sdk.trace import TracerProvider
+  from opentelemetry.sdk.trace.export import BatchSpanProcessor, SimpleSpanProcessor
+  from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+  from opentelemetry.sdk.resources import Resource
+  
+  # Check if Jaeger is available
+  def check_jaeger_connection():
+    try:
+      # Try to connect to the OTLP gRPC port
+      sock = socket.create_connection(("localhost", 4317), timeout=1)
+      sock.close()
+      return True
+    except (socket.timeout, socket.error):
+      return False
+  
+  # Create and configure the tracer
+  resource = Resource.create({
+    "service.name": "exo-distributed",
+    "service.instance.id": args.node_id
+  })
+  
+  tracer_provider = TracerProvider(resource=resource)
+  
+  if check_jaeger_connection():
+    print("Jaeger connection successful, setting up tracing...")
+    # Configure the OTLP exporter with better defaults for high throughput
+    otlp_exporter = OTLPSpanExporter(
+      endpoint="http://localhost:4317",
+      # Increase timeout to handle larger batches
+      timeout=30.0,
+    )
+    
+    # Configure the BatchSpanProcessor with appropriate batch settings
+    span_processor = BatchSpanProcessor(
+      otlp_exporter,
+      # Reduce export frequency
+      schedule_delay_millis=5000,
+      # Increase max batch size
+      max_export_batch_size=512,
+      # Limit queue size to prevent memory issues
+      max_queue_size=2048,
+    )
+    
+    tracer_provider.add_span_processor(span_processor)
+  else:
+    print("Warning: Could not connect to Jaeger, tracing will be disabled")
+    # Use a no-op span processor if Jaeger is not available
+    from opentelemetry.sdk.trace.export import ConsoleSpanExporter
+    tracer_provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
+  
+  # Set the tracer provider
+  trace.set_tracer_provider(tracer_provider)
+
   # Check HuggingFace directory permissions
   hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
   if DEBUG >= 1: print(f"Model storage directory: {hf_home}")

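The block added to main() above registers a global OpenTelemetry tracer provider that exports spans to a local Jaeger collector over OTLP when port 4317 is reachable. Once trace.set_tracer_provider has run, any module in the process can pick up a tracer from the global API; a small illustrative sketch (the span and attribute names here are assumptions, not part of the commit):

from opentelemetry import trace

tracer = trace.get_tracer("exo-distributed")

with tracer.start_as_current_span("startup-check") as span:
  span.set_attribute("node_id", "example-node")  # hypothetical attribute value
  span.add_event("node initialized")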
+ 111 - 26
exo/networking/grpc/grpc_peer_handle.py

@@ -90,34 +90,66 @@ class GRPCPeerHandle(PeerHandle):
         traceback.print_exc()
       return False
 
-  async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> None:
-    request = node_service_pb2.PromptRequest(
-      prompt=prompt,
+  async def send_prompt(
+    self,
+    shard: Shard,
+    prompt: str,
+    request_id: Optional[str] = None,
+    sequence_number: Optional[int] = None,
+    trace_parent: Optional[str] = None
+  ) -> None:
+    request = node_service_pb2.SendPromptRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,
         end_layer=shard.end_layer,
         n_layers=shard.n_layers,
       ),
+      prompt=prompt,
       request_id=request_id,
+      sequence_number=sequence_number,
+      trace_parent=trace_parent
     )
     await self.stub.SendPrompt(request)
 
-  async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
-    request = node_service_pb2.TensorRequest(
+  async def send_tensor(
+    self,
+    shard: Shard,
+    tensor: np.ndarray,
+    request_id: Optional[str] = None,
+    sequence_number: Optional[int] = None,
+    trace_parent: Optional[str] = None
+  ) -> None:
+    request = node_service_pb2.SendTensorRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,
         end_layer=shard.end_layer,
         n_layers=shard.n_layers,
       ),
-      tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
+      tensor=node_service_pb2.Tensor(
+        tensor_data=tensor.tobytes(),
+        shape=tensor.shape,
+        dtype=str(tensor.dtype)
+      ),
       request_id=request_id,
+      sequence_number=sequence_number,
+      trace_parent=trace_parent
     )
     await self.stub.SendTensor(request)
-  
-  async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
-    request = node_service_pb2.ExampleRequest(
+
+  async def send_example(
+    self,
+    shard: Shard,
+    example: np.ndarray,
+    target: np.ndarray,
+    length: np.ndarray,
+    train: bool,
+    request_id: Optional[str] = None,
+    sequence_number: Optional[int] = None,
+    trace_parent: Optional[str] = None
+  ) -> Optional[np.array]:
+    request = node_service_pb2.SendExampleRequest(
       shard=node_service_pb2.Shard(
         model_id=shard.model_id,
         start_layer=shard.start_layer,
@@ -129,6 +161,8 @@ class GRPCPeerHandle(PeerHandle):
       length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)),
       train=train,
       request_id=request_id,
+      sequence_number=sequence_number,
+      trace_parent=trace_parent
     )
     response = await self.stub.SendExample(request)
     loss = response.loss
@@ -137,7 +171,7 @@ class GRPCPeerHandle(PeerHandle):
       return loss, grads
     else:
       return loss
-  
+
   async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
     request = node_service_pb2.TensorRequest(
       shard=node_service_pb2.Shard(
@@ -156,27 +190,78 @@ class GRPCPeerHandle(PeerHandle):
 
     return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
 
-  async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
-    request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
+  async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
+    if DEBUG >= 2: print(f"[GRPCPeerHandle] Collecting topology from {self.id()} with {visited=} {max_depth=}")
+    
+    # Convert set to list for GRPC request
+    request = node_service_pb2.CollectTopologyRequest(
+      visited=list(visited),
+      max_depth=max_depth
+    )
+    
+    # Make GRPC call
     response = await self.stub.CollectTopology(request)
+    if DEBUG >= 2: print(f"[GRPCPeerHandle] Got topology response from {self.id()}")
+    
+    # Convert proto topology to Topology object
     topology = Topology()
-    for node_id, capabilities in response.nodes.items():
-      device_capabilities = DeviceCapabilities(
-        model=capabilities.model,
-        chip=capabilities.chip,
-        memory=capabilities.memory,
-        flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
+    proto_topology = response.topology
+    
+    # Convert nodes and their capabilities
+    for node in proto_topology.nodes:
+      # Convert DeviceCapabilities
+      flops = DeviceFlops(
+        fp32=node.capabilities.flops.fp32,
+        fp16=node.capabilities.flops.fp16,
+        int8=node.capabilities.flops.int8
+      )
+      capabilities = DeviceCapabilities(
+        model=node.capabilities.model,
+        chip=node.capabilities.chip,
+        memory=node.capabilities.memory,
+        flops=flops
       )
       )
-    for node_id, peer_connections in response.peer_graph.items():
-      for conn in peer_connections.connections:
-        topology.add_edge(node_id, conn.to_id, conn.description)
+      
+      # Add node to topology
+      topology.update_node(node.id, capabilities)
+      
+      # Add connections
+      for conn in node.connections:
+        topology.add_edge(node.id, conn.to_id, conn.description if conn.HasField("description") else None)
+    
+    # Set active node
+    if proto_topology.HasField("active_node_id"):
+      topology.active_node_id = proto_topology.active_node_id
+    
+    if DEBUG >= 2: print(f"[GRPCPeerHandle] Converted topology from {self.id()} with {len(topology.nodes)} nodes")
     return topology
     return topology
 
-    request = node_service_pb2.SendNewTokenRequest(request_id=request_id, token=token, is_finished=is_finished)
+  async def send_new_token(
+    self,
+    request_id: str,
+    token: int,
+    is_finished: bool,
+    sequence_number: Optional[int] = None,
+    trace_parent: Optional[str] = None
+  ) -> None:
+    request = node_service_pb2.SendNewTokenRequest(
+      request_id=request_id,
+      token=token,
+      is_finished=is_finished,
+      sequence_number=sequence_number,
+      trace_parent=trace_parent
+    )
     await self.stub.SendNewToken(request)
 
-  async def send_opaque_status(self, request_id: str, status: str) -> None:
-    request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
+  async def send_opaque_status(
+    self,
+    request_id: str,
+    status: str,
+    trace_parent: Optional[str] = None
+  ) -> None:
+    request = node_service_pb2.SendOpaqueStatusRequest(
+      request_id=request_id,
+      status=status,
+      trace_parent=trace_parent
+    )
     await self.stub.SendOpaqueStatus(request)

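The new sequence_number and trace_parent fields attached to these requests let the receiving node stitch its spans into the caller's trace: trace_parent is expected to carry a W3C Trace Context traceparent string. A small sketch of how such a string is produced with OpenTelemetry's propagator (illustrative only, assuming the standard opentelemetry-api; the project's own helper may differ):

from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

tracer = trace.get_tracer("exo")
carrier = {}
with tracer.start_as_current_span("forward_prompt") as span:
  # Serializes the current span context, e.g. "00-<trace_id>-<span_id>-01".
  TraceContextTextMapPropagator().inject(carrier)
trace_parent = carrier.get("traceparent", "")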
+ 121 - 26
exo/networking/grpc/grpc_server.py

@@ -58,8 +58,21 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     prompt = request.prompt
     request_id = request.request_id
+    sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None
+    trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None
+    
+    # Update trace context if sequence number or trace parent is provided
+    if sequence_number is not None or trace_parent is not None:
+      from exo.orchestration.tracing import tracer, TraceContext
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number or 0,
+        trace_parent=trace_parent
+      )
+      tracer.set_context(request_id, context)
+    
     await self.node.process_prompt(shard, prompt, request_id)
-    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=}")
+    if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} {sequence_number=}")
     return node_service_pb2.Empty()
 
   async def SendTensor(self, request, context):
@@ -71,8 +84,21 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     )
     tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
     request_id = request.request_id
+    sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None
+    trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None
+    
+    # Update trace context if sequence number or trace parent is provided
+    if sequence_number is not None or trace_parent is not None:
+      from exo.orchestration.tracing import tracer, TraceContext
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number or 0,
+        trace_parent=trace_parent
+      )
+      tracer.set_context(request_id, context)
+    
     await self.node.process_tensor(shard, tensor, request_id)
-    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=}")
+    if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} {sequence_number=}")
     return node_service_pb2.Empty()
   
   async def SendExample(self, request, context):
@@ -87,6 +113,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
     length = np.frombuffer(request.length.tensor_data, dtype=np.dtype(request.length.dtype)).reshape(request.length.shape)
     train = request.train
     request_id = request.request_id
+    sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None
+    trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None
+    
+    # Update trace context if sequence number or trace parent is provided
+    if sequence_number is not None or trace_parent is not None:
+      from exo.orchestration.tracing import tracer, TraceContext
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number or 0,
+        trace_parent=trace_parent
+      )
+      tracer.set_context(request_id, context)
 
     if train and not shard.is_first_layer():
       loss, grad = await self.node.process_example(shard, example, target, length, train, request_id)
@@ -97,43 +135,100 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
       loss = await self.node.process_example(shard, example, target, length, train, request_id)
       return node_service_pb2.Loss(loss=loss, grads=None)
     
-  async def CollectTopology(self, request, context):
-    max_depth = request.max_depth
+  async def CollectTopology(
+    self,
+    request: node_service_pb2.CollectTopologyRequest,
+    context: grpc.aio.ServicerContext,
+  ) -> node_service_pb2.CollectTopologyResponse:
+    # Convert visited list back to set
     visited = set(request.visited)
-    topology = self.node.current_topology
-    nodes = {
-      node_id:
-        node_service_pb2.DeviceCapabilities(
-          model=cap.model,
-          chip=cap.chip,
-          memory=cap.memory,
-          flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
-        )
-      for node_id, cap in topology.nodes.items()
-    }
-    peer_graph = {
-      node_id: node_service_pb2.PeerConnections(
-        connections=[
-          node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
-          for conn in connections
-        ]
+    if DEBUG >= 2: print(f"[GRPCServer] CollectTopology request with {visited=} {request.max_depth=}")
+    
+    # Get topology from node
+    topology = await self.node.collect_topology(visited, request.max_depth)
+    if DEBUG >= 2: print(f"[GRPCServer] Got topology: {topology}")
+    
+    # Convert Topology to proto message
+    proto_topology = node_service_pb2.CollectTopologyResponse.Topology()
+    
+    # Convert nodes and their capabilities
+    for node_id, capabilities in topology.nodes.items():
+      # Create DeviceFlops
+      flops = node_service_pb2.CollectTopologyResponse.DeviceFlops(
+        fp32=capabilities.flops.fp32,
+        fp16=capabilities.flops.fp16,
+        int8=capabilities.flops.int8
       )
-      for node_id, connections in topology.peer_graph.items()
-    }
-    if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
-    return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
+      
+      # Create DeviceCapabilities
+      device_caps = node_service_pb2.CollectTopologyResponse.DeviceCapabilities(
+        model=capabilities.model,
+        chip=capabilities.chip,
+        memory=capabilities.memory,
+        flops=flops
+      )
+      
+      # Get connections for this node
+      connections = []
+      if node_id in topology.peer_graph:
+        for conn in topology.peer_graph[node_id]:
+          proto_conn = node_service_pb2.CollectTopologyResponse.PeerConnection(
+            to_id=conn.to_id,
+            description=conn.description if conn.description else None
+          )
+          connections.append(proto_conn)
+      
+      # Create Node with its connections
+      node = node_service_pb2.CollectTopologyResponse.Node(
+        id=node_id,
+        capabilities=device_caps,
+        connections=connections
+      )
+      proto_topology.nodes.append(node)
+    
+    # Set active node if present
+    if topology.active_node_id:
+      proto_topology.active_node_id = topology.active_node_id
+    
+    if DEBUG >= 2: print(f"[GRPCServer] Sending topology response with {len(proto_topology.nodes)} nodes")
+    return node_service_pb2.CollectTopologyResponse(topology=proto_topology)
 
   async def SendNewToken(self, request, context):
     request_id = request.request_id
     token = request.token
     is_finished = request.is_finished
-    if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=}")
+    sequence_number = request.sequence_number if hasattr(request, 'sequence_number') else None
+    trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None
+    
+    # Update trace context if sequence number or trace parent is provided
+    if sequence_number is not None or trace_parent is not None:
+      from exo.orchestration.tracing import tracer, TraceContext
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number or 0,
+        trace_parent=trace_parent
+      )
+      tracer.set_context(request_id, context)
+    
+    if DEBUG >= 5: print(f"Received SendNewToken request: {request_id=} {token=} {is_finished=} {sequence_number=}")
     self.node.on_token.trigger_all(request_id, token, is_finished)
     return node_service_pb2.Empty()
 
   async def SendOpaqueStatus(self, request, context):
     request_id = request.request_id
     status = request.status
+    trace_parent = request.trace_parent if hasattr(request, 'trace_parent') else None
+    
+    # Update trace context if trace parent is provided
+    if trace_parent is not None:
+      from exo.orchestration.tracing import tracer, TraceContext
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=0,
+        trace_parent=trace_parent
+      )
+      tracer.set_context(request_id, context)
+    
     if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
     self.node.on_opaque_status.trigger_all(request_id, status)
     return node_service_pb2.Empty()

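Both this server and node.py below import tracer and TraceContext from exo.orchestration.tracing, a module that is not included in the diff shown. A minimal sketch of what it might look like, inferred from the call sites (get_context/set_context, start_span with extra_attributes, inject_context, handle_token); every detail here is an assumption rather than the actual implementation:

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Optional

from opentelemetry import trace
from opentelemetry.trace import Span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator


@dataclass
class TraceContext:
  request_id: str
  sequence_number: int
  trace_parent: Optional[str] = None   # W3C traceparent string received from the caller
  request_span: Optional[Span] = None  # long-lived span covering the whole request
  current_span: Optional[Span] = None  # most recently started span for this context
  token_count: int = 0


class Tracer:
  def __init__(self):
    self.tracer = trace.get_tracer("exo")
    self.contexts: Dict[str, TraceContext] = {}
    self.propagator = TraceContextTextMapPropagator()

  def get_context(self, request_id: Optional[str]) -> Optional[TraceContext]:
    return self.contexts.get(request_id)

  def set_context(self, request_id: str, context: TraceContext) -> None:
    self.contexts[request_id] = context

  def inject_context(self, span: Span) -> str:
    # Serialize the span's context into a traceparent string for outgoing gRPC requests.
    carrier: Dict[str, str] = {}
    self.propagator.inject(carrier, context=trace.set_span_in_context(span))
    return carrier.get("traceparent", "")

  @contextmanager
  def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[dict] = None):
    # Parent the span on the propagated traceparent if present, else on the request span.
    parent_ctx = None
    if context.trace_parent:
      parent_ctx = self.propagator.extract({"traceparent": context.trace_parent})
    elif context.request_span is not None:
      parent_ctx = trace.set_span_in_context(context.request_span)
    attributes = {"request_id": context.request_id, "sequence_number": context.sequence_number}
    attributes.update(extra_attributes or {})
    with self.tracer.start_as_current_span(name, context=parent_ctx, attributes=attributes) as span:
      context.current_span = span
      yield span

  def handle_token(self, context: TraceContext, token: int, is_finished: bool) -> None:
    # Record generation progress on the request span as tokens are broadcast.
    context.token_count += 1
    if context.request_span is not None:
      context.request_span.set_attribute("token_count", context.token_count)
      if is_finished:
        context.request_span.add_event("generation finished")


tracer = Tracer()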
+ 49 - 35
exo/networking/grpc/node_service.proto

@@ -3,10 +3,10 @@ syntax = "proto3";
 package node_service;
 
 service NodeService {
-  rpc SendPrompt (PromptRequest) returns (Empty) {}
-  rpc SendTensor (TensorRequest) returns (Empty) {}
-  rpc SendExample (ExampleRequest) returns (Loss) {}
-  rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
+  rpc SendPrompt (SendPromptRequest) returns (Empty) {}
+  rpc SendTensor (SendTensorRequest) returns (Empty) {}
+  rpc SendExample (SendExampleRequest) returns (Empty) {}
+  rpc CollectTopology (CollectTopologyRequest) returns (CollectTopologyResponse) {}
   rpc SendNewToken (SendNewTokenRequest) returns (Empty) {}
   rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
   rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
@@ -19,25 +19,30 @@ message Shard {
   int32 n_layers = 4;
 }
 
-message PromptRequest {
+message SendPromptRequest {
   Shard shard = 1;
   string prompt = 2;
-  optional string request_id = 3;
+  string request_id = 3;
+  int32 sequence_number = 4;
+  string trace_parent = 5;
 }
 
-message TensorRequest {
+message SendTensorRequest {
   Shard shard = 1;
   Tensor tensor = 2;
-  optional string request_id = 3;
+  string request_id = 3;
+  int32 sequence_number = 4;
+  string trace_parent = 5;
 }
 
-message ExampleRequest {
+message SendExampleRequest {
   Shard shard = 1;
-  Tensor example = 2;
-  Tensor target = 3;
-  Tensor length = 4;
-  bool train = 5;
-  optional string request_id = 6;
+  bytes example = 2;
+  bytes target = 3;
+  bytes length = 4;
+  string request_id = 5;
+  bool train = 6;
+  string trace_parent = 7;
 }
 
 message Loss {
@@ -56,42 +61,51 @@ message CollectTopologyRequest {
   int32 max_depth = 2;
 }
 
-message Topology {
-  map<string, DeviceCapabilities> nodes = 1;
-  map<string, PeerConnections> peer_graph = 2;
-}
+message CollectTopologyResponse {
+  message DeviceFlops {
+    double fp32 = 1;
+    double fp16 = 2;
+    double int8 = 3;
+  }
 
-message PeerConnection {
-  string to_id = 1;
-  optional string description = 2;
-}
+  message DeviceCapabilities {
+    string model = 1;
+    string chip = 2;
+    int32 memory = 3;
+    DeviceFlops flops = 4;
+  }
 
-message PeerConnections {
-  repeated PeerConnection connections = 1;
-}
+  message PeerConnection {
+    string to_id = 1;
+    optional string description = 2;
+  }
 
-message DeviceFlops {
-  double fp32 = 1;
-  double fp16 = 2;
-  double int8 = 3;
-}
+  message Node {
+    string id = 1;
+    DeviceCapabilities capabilities = 2;
+    repeated PeerConnection connections = 3;
+  }
+
+  message Topology {
+    repeated Node nodes = 1;
+    optional string active_node_id = 2;
+  }
 
-message DeviceCapabilities {
-  string model = 1;
-  string chip = 2;
-  int32 memory = 3;
-  DeviceFlops flops = 4;
+  Topology topology = 1;
 }
 
 message SendNewTokenRequest {
   string request_id = 1;
   int32 token = 2;
   bool is_finished = 3;
+  int32 sequence_number = 4;
+  string trace_parent = 5;
 }
 
 message SendOpaqueStatusRequest {
   string request_id = 1;
   string status = 2;
+  string trace_parent = 3;
 }
 
 message HealthCheckRequest {}

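Because the request and response message names changed, the generated node_service_pb2.py and node_service_pb2_grpc.py files below have to be regenerated from this .proto file. A sketch of the usual grpcio-tools invocation, run from exo/networking/grpc (the project may use a different command):

from grpc_tools import protoc

protoc.main([
  "protoc",            # argv[0] placeholder expected by protoc.main
  "-I.",
  "--python_out=.",
  "--grpc_python_out=.",
  "node_service.proto",
])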
File diff not shown because it is too large
+ 0 - 0
exo/networking/grpc/node_service_pb2.py


+ 15 - 15
exo/networking/grpc/node_service_pb2_grpc.py

@@ -36,23 +36,23 @@ class NodeServiceStub(object):
         """
         """
         self.SendPrompt = channel.unary_unary(
                 '/node_service.NodeService/SendPrompt',
-                request_serializer=node__service__pb2.PromptRequest.SerializeToString,
+                request_serializer=node__service__pb2.SendPromptRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendTensor = channel.unary_unary(
                 '/node_service.NodeService/SendTensor',
-                request_serializer=node__service__pb2.TensorRequest.SerializeToString,
+                request_serializer=node__service__pb2.SendTensorRequest.SerializeToString,
                 response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.SendExample = channel.unary_unary(
                 '/node_service.NodeService/SendExample',
-                request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Loss.FromString,
+                request_serializer=node__service__pb2.SendExampleRequest.SerializeToString,
+                response_deserializer=node__service__pb2.Empty.FromString,
                 _registered_method=True)
         self.CollectTopology = channel.unary_unary(
                 '/node_service.NodeService/CollectTopology',
                 request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
-                response_deserializer=node__service__pb2.Topology.FromString,
+                response_deserializer=node__service__pb2.CollectTopologyResponse.FromString,
                 _registered_method=True)
         self.SendNewToken = channel.unary_unary(
                 '/node_service.NodeService/SendNewToken',
@@ -121,23 +121,23 @@ def add_NodeServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'SendPrompt': grpc.unary_unary_rpc_method_handler(
                     servicer.SendPrompt,
-                    request_deserializer=node__service__pb2.PromptRequest.FromString,
+                    request_deserializer=node__service__pb2.SendPromptRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendTensor': grpc.unary_unary_rpc_method_handler(
                     servicer.SendTensor,
-                    request_deserializer=node__service__pb2.TensorRequest.FromString,
+                    request_deserializer=node__service__pb2.SendTensorRequest.FromString,
                     response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'SendExample': grpc.unary_unary_rpc_method_handler(
                     servicer.SendExample,
-                    request_deserializer=node__service__pb2.ExampleRequest.FromString,
-                    response_serializer=node__service__pb2.Loss.SerializeToString,
+                    request_deserializer=node__service__pb2.SendExampleRequest.FromString,
+                    response_serializer=node__service__pb2.Empty.SerializeToString,
             ),
             'CollectTopology': grpc.unary_unary_rpc_method_handler(
                     servicer.CollectTopology,
                     request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
-                    response_serializer=node__service__pb2.Topology.SerializeToString,
+                    response_serializer=node__service__pb2.CollectTopologyResponse.SerializeToString,
             ),
             'SendNewToken': grpc.unary_unary_rpc_method_handler(
                     servicer.SendNewToken,
@@ -180,7 +180,7 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendPrompt',
-            node__service__pb2.PromptRequest.SerializeToString,
+            node__service__pb2.SendPromptRequest.SerializeToString,
             node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
@@ -207,7 +207,7 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendTensor',
-            node__service__pb2.TensorRequest.SerializeToString,
+            node__service__pb2.SendTensorRequest.SerializeToString,
             node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
@@ -234,8 +234,8 @@ class NodeService(object):
             request,
             target,
             '/node_service.NodeService/SendExample',
-            node__service__pb2.ExampleRequest.SerializeToString,
-            node__service__pb2.Loss.FromString,
+            node__service__pb2.SendExampleRequest.SerializeToString,
+            node__service__pb2.Empty.FromString,
             options,
             channel_credentials,
             insecure,
@@ -262,7 +262,7 @@ class NodeService(object):
             target,
             '/node_service.NodeService/CollectTopology',
             node__service__pb2.CollectTopologyRequest.SerializeToString,
-            node__service__pb2.Topology.FromString,
+            node__service__pb2.CollectTopologyResponse.FromString,
             options,
             channel_credentials,
             insecure,

+ 218 - 98
exo/orchestration/node.py

@@ -16,6 +16,7 @@ from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.orchestration.tracing import tracer, TraceContext
 
 class Node:
   def __init__(
@@ -111,44 +112,79 @@ class Node:
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
   
-  token_count = 0
-  first_token_time = 0
   async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
     request_id: Optional[str] = None,
   ):
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
-    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-    
-    if shard.is_last_layer() and not is_finished:
-      self.token_count += 1
-      if self.token_count == 1:
-        self.first_token_time = time.perf_counter_ns()
-      if self.token_count % 20 == 0:
-        print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
-
-      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
-      await self.inference_engine.ensure_shard(shard)
-      self.buffered_token_output[request_id][0].append(token.item())
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      forward = token.reshape(1, -1)
-      self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
-      asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
-    else:
-      forward = result
-
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      self.outstanding_requests.pop(request_id)
-    else:
-      self.outstanding_requests[request_id] = "waiting"
-      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
+      tracer.set_context(request_id, context)
 
-    return np.array(self.buffered_token_output[request_id][0])
+    try:
+      with tracer.start_span(
+        f"process_inference_result.{self.get_partition_index()}",
+        context,
+        extra_attributes={
+          "partition_index": self.get_partition_index(),
+          "node_id": self.id,
+          "start_layer": shard.start_layer,
+          "end_layer": shard.end_layer
+        }
+      ):
+        if request_id not in self.buffered_token_output:
+          self.buffered_token_output[request_id] = ([], False)
+        is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+        
+        if shard.is_last_layer() and not is_finished:
+          token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+          forward = token.reshape(1, -1)
+          
+          # Increment sequence number for next forward pass
+          next_sequence = context.sequence_number + 1
+          # Create new context but preserve request span
+          next_context = TraceContext(
+            request_id=context.request_id, 
+            sequence_number=next_sequence,
+            request_span=context.request_span  # Preserve request span
+          )
+          tracer.set_context(request_id, next_context)
+          
+          self.buffered_token_output[request_id][0].append(token.item())
+          is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished
+          self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
+          await self.broadcast_new_token(request_id, token.item(), is_finished)
+          
+          if not is_finished:
+            self.outstanding_requests[request_id] = "waiting"
+            asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+        else:
+          forward = result
+          if not is_finished:
+            self.outstanding_requests[request_id] = "waiting"
+            asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+
+        if is_finished:
+          # End the request span when generation is complete
+          if context.request_span:
+            context.request_span.set_attribute("total_tokens", len(self.buffered_token_output[request_id][0]))
+            context.request_span.end()
+            context.request_span = None
+          self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+          self.outstanding_requests.pop(request_id)
+
+        return np.array(self.buffered_token_output[request_id][0])
+    except Exception as e:
+      if request_id in self.outstanding_requests:
+        self.outstanding_requests.pop(request_id)
+      # End request span on error
+      if context and context.request_span:
+        context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
+        context.request_span.end()
+        context.request_span = None
+      raise
 
   async def process_prompt(
     self,
@@ -195,18 +231,46 @@ class Node:
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
+      
+    # Create or get trace context
+    context = tracer.get_context(request_id)
+    if not context:
+      # Create new context with request span
+      request_span = tracer.tracer.start_span(
+        "request",
+        attributes={
+          "request_id": request_id,
+          "prompt": prompt,
+          "node_id": self.id,
+          "request_type": "process_prompt"
+        }
+      )
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=0,
+        request_span=request_span,
+        current_span=request_span,
+        trace_parent=tracer.inject_context(request_span)
+      )
+      tracer.set_context(request_id, context)
+      
     shard = self.get_current_shard(base_shard)
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
 
-    if not shard.is_first_layer():
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      self.outstanding_requests[request_id] = "waiting"
-      await self.forward_prompt(shard, prompt, request_id, 0)
-      return None
+    try:
+      if not shard.is_first_layer():
+        if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+        self.outstanding_requests[request_id] = "waiting"
+        await self.forward_prompt(shard, prompt, request_id, 0)
+        return None
 
-    self.outstanding_requests[request_id] = "processing"
-    result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-    await self.process_inference_result(shard, result, request_id)
+      self.outstanding_requests[request_id] = "processing"
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      await self.process_inference_result(shard, result, request_id)
+    except Exception as e:
+      if context.request_span:
+        context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
+      raise
 
   async def enqueue_example(
     self,
@@ -350,33 +414,36 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-  ) -> None:
-    shard = self.get_current_shard(base_shard)
-    start_time = time.perf_counter_ns()
-    await self._process_tensor(shard, tensor, request_id)
-    end_time = time.perf_counter_ns()
-    elapsed_time_ns = end_time - start_time
-    if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
-
-  async def _process_tensor(
-    self,
-    base_shard: Shard,
-    tensor: np.ndarray,
-    request_id: Optional[str] = None,
-  ) -> None:
-    if request_id is None:
-      request_id = str(uuid.uuid4())
-    shard = self.get_current_shard(base_shard)
+  ):
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
+      tracer.set_context(request_id, context)
 
     try:
       self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      await self.process_inference_result(shard, result, request_id) 
+      with tracer.start_span(
+        f"process_tensor.{self.get_partition_index()}",
+        context,
+        extra_attributes={
+          "partition_index": self.get_partition_index(),
+          "node_id": self.id,
+          "start_layer": base_shard.start_layer,
+          "end_layer": base_shard.end_layer,
+          "tensor_shape": str(tensor.shape)
+        }
+      ):
+        result = await self.inference_engine.infer_tensor(request_id, base_shard, tensor)
+        await self.process_inference_result(base_shard, result, request_id)
     except Exception as e:
-      self.outstanding_requests.pop(request_id)
-      print(f"Error processing tensor for shard {shard}: {e}")
+      if request_id in self.outstanding_requests:
+        self.outstanding_requests.pop(request_id)
+      if context and context.request_span:
+        context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
+      print(f"Error processing tensor for shard {base_shard}: {e}")
       traceback.print_exc()
-  
+      raise
+
   async def forward_example(
     self,
     base_shard: Shard,
@@ -405,18 +472,39 @@ class Node:
     request_id: str,
     target_index: int,
   ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    next_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
-    if target_id == self.id:
-      await self.process_prompt(next_shard, prompt, request_id)
-    else:
-      target_peer = next((p for p in self.peers if p.id() == target_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {target_index} not found")
-      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id, sequence_number=0)
+      tracer.set_context(request_id, context)
+
+    with tracer.start_span(
+      "forward_prompt",
+      context,
+      extra_attributes={
+        "source_node": self.id,
+        "target_index": target_index,
+        "prompt": prompt
+      }
+    ) as span:
+      if DEBUG >= 1: print(f"target partition index: {target_index}")
+      target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+      next_shard = self.get_current_shard(base_shard, target_index)
+      span.set_attribute("target_node", target_id)
+      
+      # Get trace context for propagation
+      trace_parent = tracer.inject_context(span)
+      
+      if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+      if target_id == self.id:
+        # Update local context with trace parent
+        context.trace_parent = trace_parent
+        await self.process_prompt(next_shard, prompt, request_id)
+      else:
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {target_index} not found")
+        if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+        await target_peer.send_prompt(next_shard, prompt, request_id=request_id, sequence_number=context.sequence_number, trace_parent=trace_parent)
   
   async def forward_tensor(
     self,
@@ -424,19 +512,39 @@ class Node:
     tensor: np.ndarray,
     request_id: str,
     target_index: int,
-  ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    next_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
-    if target_id == self.id:
-      await self.process_tensor(next_shard, tensor, request_id)
-    else:
-      target_peer = next((p for p in self.peers if p.id() == target_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {target_index} not found")
-      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
-      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+  ) -> None:
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id, sequence_number=0)
+      tracer.set_context(request_id, context)
+
+    with tracer.start_span(
+      "forward_tensor",
+      context,
+      extra_attributes={
+        "source_node": self.id,
+        "target_index": target_index,
+        "tensor_shape": str(tensor.shape)
+      }
+    ) as span:
+      target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+      next_shard = self.get_current_shard(base_shard, target_index)
+      span.set_attribute("target_node", target_id)
+      
+      # Get trace context for propagation
+      trace_parent = tracer.inject_context(context.request_span or span)
+      
+      if target_id == self.id:
+        # Update local context with trace parent
+        context.trace_parent = trace_parent
+        await self.process_tensor(next_shard, tensor, request_id)
+      else:
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {target_index} not found")
+        
+        if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+        await target_peer.send_tensor(next_shard, tensor, request_id=request_id, sequence_number=context.sequence_number, trace_parent=trace_parent)
 
 
   def get_partition_index(self, offset: int = 0):
   def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
     if not self.partitioning_strategy:
@@ -570,20 +678,32 @@ class Node:
     return self._on_opaque_status
     return self._on_opaque_status
 
 
   def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
   def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
+    if DEBUG >= 2: print(f"[Node] Triggering token callbacks: {request_id=} {token=} {is_finished=}")
     self.on_token.trigger_all(request_id, token, is_finished)
     self.on_token.trigger_all(request_id, token, is_finished)
   
   
-  async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
-    async def send_new_token_to_peer(peer):
-      try:
-        await asyncio.wait_for(peer.send_new_token(request_id, token, is_finished), timeout=15.0)
-      except asyncio.TimeoutError:
-        print(f"Timeout broadcasting new token to {peer.id()}")
-      except Exception as e:
-        print(f"Error broadcasting new token to {peer.id()}: {e}")
-        traceback.print_exc()
-
-    await asyncio.gather(*[send_new_token_to_peer(peer) for peer in self.peers], return_exceptions=True)
+  async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
+    """Broadcast a new token to all peers."""
+    context = tracer.get_context(request_id)
+    if context:
+      # Handle token in tracer for grouping
+      tracer.handle_token(context, token, is_finished)
+      # Get current trace context for propagation
+      trace_parent = ""
+      if context.current_span:
+        trace_parent = tracer.inject_context(context.current_span)
+    
+    tasks = []
+    for peer in self.peers:
+      tasks.append(
+        peer.send_new_token(
+          request_id,
+          token,
+          is_finished,
+          context.sequence_number if context else 0,
+          trace_parent if context else ""
+        )
+      )
+    # Don't let a single failing peer abort the whole broadcast.
+    await asyncio.gather(*tasks, return_exceptions=True)
 
 
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
     if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")
     if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")

+ 166 - 0
exo/orchestration/tracing.py

@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Any
+from opentelemetry import trace, context
+from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from contextlib import contextmanager
+import time
+from threading import Lock
+
+@dataclass
+class TraceContext:
+  request_id: str
+  sequence_number: int
+  current_span: Optional[trace.Span] = None
+  trace_parent: Optional[str] = None
+  token_group_span: Optional[trace.Span] = None
+  token_count: int = 0
+  token_group_size: int = 10  # Default group size
+  request_span: Optional[trace.Span] = None  # Track the main request span
+
+class Tracer:
+  def __init__(self):
+    self.tracer = trace.get_tracer("exo")
+    self.contexts: Dict[str, TraceContext] = {}
+    self._lock = Lock()
+    self.propagator = TraceContextTextMapPropagator()
+    
+  def get_context(self, request_id: str) -> Optional[TraceContext]:
+    with self._lock:
+      return self.contexts.get(request_id)
+
+  def set_context(self, request_id: str, context: TraceContext):
+    with self._lock:
+      self.contexts[request_id] = context
+
+  def inject_context(self, span: trace.Span) -> str:
+    """Inject current span context into carrier for propagation"""
+    carrier = {}
+    ctx = trace.set_span_in_context(span)
+    self.propagator.inject(carrier, context=ctx)
+    return carrier.get("traceparent", "")
+
+  def extract_context(self, trace_parent: str) -> Optional[context.Context]:
+    """Extract span context from carrier"""
+    if not trace_parent:
+      return None
+    carrier = {"traceparent": trace_parent}
+    return self.propagator.extract(carrier)
+
+  def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
+    """Create a new context with the given trace parent"""
+    parent_ctx = self.extract_context(trace_parent)
+    if parent_ctx:
+      # Create a new request span that links to the parent context
+      request_span = self.tracer.start_span(
+        "request",
+        context=parent_ctx,
+        attributes={
+          "request_id": request_id,
+          "sequence_number": sequence_number
+        }
+      )
+      return TraceContext(
+        request_id=request_id,
+        sequence_number=sequence_number,
+        request_span=request_span,
+        current_span=request_span,
+        trace_parent=trace_parent
+      )
+    return TraceContext(request_id=request_id, sequence_number=sequence_number)
+
+  def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
+    """Handle token generation and manage token group spans"""
+    context.token_count += 1
+    
+    # Start a new token group span if needed
+    if not context.token_group_span and context.request_span:
+      group_number = (context.token_count - 1) // context.token_group_size + 1
+      
+      # Create token group span as child of request span
+      parent_ctx = trace.set_span_in_context(context.request_span)
+      context.token_group_span = self.tracer.start_span(
+        f"token_group_{group_number}",
+        context=parent_ctx,
+        attributes={
+          "request_id": context.request_id,
+          "group.number": group_number,
+          "group.start_token": context.token_count,
+          "group.max_tokens": context.token_group_size
+        }
+      )
+    
+    # Add token to current group span
+    if context.token_group_span:
+      relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
+      context.token_group_span.set_attribute(f"token.{relative_pos}", token)
+      context.token_group_span.set_attribute("token.count", relative_pos)
+      
+      # End current group span if we've reached the group size or if generation is finished
+      if context.token_count % context.token_group_size == 0 or is_finished:
+        context.token_group_span.set_attribute("token.final_count", relative_pos)
+        context.token_group_span.end()
+        context.token_group_span = None
+
+  @contextmanager
+  def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
+    """Start a new span with proper parent context"""
+    attributes = {
+      "request_id": context.request_id,
+      "sequence_number": context.sequence_number
+    }
+    if extra_attributes:
+      attributes.update(extra_attributes)
+      
+    # Use request span as parent if available
+    parent_ctx = None
+    if context.request_span:
+      parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.trace_parent:
+      parent_ctx = self.extract_context(context.trace_parent)
+      if parent_ctx and not context.request_span:
+        # Create a new request span that links to the parent context
+        context.request_span = self.tracer.start_span(
+          "request",
+          context=parent_ctx,
+          attributes={
+            "request_id": context.request_id,
+            "sequence_number": context.sequence_number
+          }
+        )
+        parent_ctx = trace.set_span_in_context(context.request_span)
+    elif context.current_span:
+      parent_ctx = trace.set_span_in_context(context.current_span)
+    
+    # Create span with parent context if it exists
+    if parent_ctx:
+      span = self.tracer.start_span(
+        name,
+        context=parent_ctx,
+        attributes=attributes
+      )
+    else:
+      span = self.tracer.start_span(
+        name,
+        attributes=attributes
+      )
+    
+    # Update context with current span
+    prev_span = context.current_span
+    context.current_span = span
+    
+    try:
+      start_time = time.perf_counter()
+      yield span
+      duration = time.perf_counter() - start_time
+      span.set_attribute("duration_s", duration)
+      span.set_status(Status(StatusCode.OK))
+    except Exception as e:
+      span.set_status(Status(StatusCode.ERROR, str(e)))
+      raise
+    finally:
+      span.end()
+      context.current_span = prev_span
+
+# Global tracer instance
+tracer = Tracer() 
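Editor's note: the hunks above only show the sending side (forward_prompt / forward_tensor passing trace_parent to peers). The receiving path is not visible in this excerpt, but create_context_from_parent, set_context and start_span are presumably what the peer uses to attach its work to the upstream trace. A minimal sketch of that receiving path follows; handle_incoming_tensor is a hypothetical name, not a function from this commit.

# Illustrative sketch (not part of this commit): rebuilding the trace context
# on the node that receives a tensor together with a W3C traceparent value.
from exo.orchestration.tracing import tracer

async def handle_incoming_tensor(node, shard, tensor, request_id, sequence_number, trace_parent):
  # Reuse an existing context for this request, or reconstruct one from the
  # traceparent that the upstream node's forward_tensor propagated.
  context = tracer.get_context(request_id)
  if context is None:
    context = tracer.create_context_from_parent(request_id, trace_parent, sequence_number)
    tracer.set_context(request_id, context)

  # Work done inside this span appears as a child of the upstream request span.
  with tracer.start_span("process_tensor", context, extra_attributes={"tensor_shape": str(tensor.shape)}):
    await node.process_tensor(shard, tensor, request_id)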

+ 4 - 0
setup.py

@@ -16,6 +16,10 @@ install_requires = [
   "nuitka==2.5.1",
   "nuitka==2.5.1",
   "nvidia-ml-py==12.560.30",
   "nvidia-ml-py==12.560.30",
   "opencv-python==4.10.0.84",
   "opencv-python==4.10.0.84",
+  "opentelemetry-api==1.29.0",
+  "opentelemetry-sdk==1.29.0",
+  "opentelemetry-exporter-otlp==1.29.0",
+  "opentelemetry-instrumentation==0.50b0",
   "pillow==10.4.0",
   "pillow==10.4.0",
   "prometheus-client==0.20.0",
   "prometheus-client==0.20.0",
   "protobuf==5.28.1",
   "protobuf==5.28.1",

Some files were not shown because too many files changed in this diff