@@ -16,6 +16,8 @@ from exo.viz.topology_viz import TopologyViz
 from exo.download.hf.hf_helpers import RepoProgressEvent
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.orchestration.tracing import tracer, TraceContext
+from opentelemetry.trace import Status, StatusCode  # used by the span error paths below

 class Node:
   def __init__(
@@ -111,44 +113,80 @@ class Node:
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool

-  token_count = 0
-  first_token_time = 0
   async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
     request_id: Optional[str] = None,
   ):
-    if request_id not in self.buffered_token_output:
-      self.buffered_token_output[request_id] = ([], False)
-    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
-    if shard.is_last_layer() and not is_finished:
-      self.token_count += 1
-      if self.token_count == 1:
-        self.first_token_time = time.perf_counter_ns()
-      if self.token_count % 20 == 0:
-        print(f"[{request_id}] TPS: {self.token_count / ((time.perf_counter_ns() - self.first_token_time) / 1e9)}")
-
-      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
-      await self.inference_engine.ensure_shard(shard)
-      self.buffered_token_output[request_id][0].append(token.item())
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      forward = token.reshape(1, -1)
-      self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
-      asyncio.create_task(self.broadcast_new_token(request_id, token.item(), is_finished))
-    else:
-      forward = result
-
-    if is_finished:
-      self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
-      self.outstanding_requests.pop(request_id)
-    else:
-      self.outstanding_requests[request_id] = "waiting"
-      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
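+    # Look up the trace context propagated with this request, creating one
+    # lazily on nodes that see the request for the first time.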
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
+      tracer.set_context(request_id, context)

-    return np.array(self.buffered_token_output[request_id][0])
+    try:
+      with tracer.start_span(
+        f"process_inference_result.{self.get_partition_index()}",
+        context,
+        extra_attributes={
+          "partition_index": self.get_partition_index(),
+          "node_id": self.id,
+          "start_layer": shard.start_layer,
+          "end_layer": shard.end_layer
+        }
+      ):
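+        # Tokens accumulate per request in buffered_token_output; a request is
+        # finished at max_generate_tokens or on an EOS token sampled below.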
+        if request_id not in self.buffered_token_output:
+          self.buffered_token_output[request_id] = ([], False)
+        is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+
+        if shard.is_last_layer() and not is_finished:
+          token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+          forward = token.reshape(1, -1)
+
+          # Increment sequence number for next forward pass
+          next_sequence = context.sequence_number + 1
+          # Create new context but preserve request span
+          next_context = TraceContext(
+            request_id=context.request_id,
+            sequence_number=next_sequence,
+            request_span=context.request_span  # Preserve request span
+          )
+          tracer.set_context(request_id, next_context)
+
+          self.buffered_token_output[request_id][0].append(token.item())
+          is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished
+          self.trigger_on_token_callbacks(request_id, token.item(), is_finished)
+          await self.broadcast_new_token(request_id, token.item(), is_finished)
+        else:
+          forward = result
+
+        if not is_finished:
+          self.outstanding_requests[request_id] = "waiting"
+          asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+
+        if is_finished:
+          # End the request span when generation is complete
+          if context.request_span:
+            context.request_span.set_attribute("total_tokens", len(self.buffered_token_output[request_id][0]))
+            context.request_span.end()
+            context.request_span = None
+          self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+          self.outstanding_requests.pop(request_id)
+
+        return np.array(self.buffered_token_output[request_id][0])
+    except Exception as e:
+      if request_id in self.outstanding_requests:
+        self.outstanding_requests.pop(request_id)
+      # End request span on error
+      if context and context.request_span:
+        context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
+        context.request_span.end()
+        context.request_span = None
+      raise

   async def process_prompt(
     self,
@@ -195,18 +233,52 @@ class Node:
   async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
+
+    # Create or get trace context
+    context = tracer.get_context(request_id)
+    if not context:
+      # Create new context with request span
+      request_span = tracer.tracer.start_span(
+        "request",
+        attributes={
+          "request_id": request_id,
+          "prompt": prompt,
+          "node_id": self.id,
+          "request_type": "process_prompt"
+        }
+      )
+      context = TraceContext(
+        request_id=request_id,
+        sequence_number=0,
+        request_span=request_span,
+        current_span=request_span,
+        trace_parent=tracer.inject_context(request_span)
+      )
+      tracer.set_context(request_id, context)
+
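+    # The request span stays open across every forward pass for this request;
+    # it is closed in process_inference_result on completion or error.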
     shard = self.get_current_shard(base_shard)
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")

-    if not shard.is_first_layer():
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      self.outstanding_requests[request_id] = "waiting"
-      await self.forward_prompt(shard, prompt, request_id, 0)
-      return None
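+    # Only the node holding the first layer runs the prompt locally; any other
+    # node forwards it to partition 0 and returns immediately.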
+    try:
+      if not shard.is_first_layer():
+        if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+        self.outstanding_requests[request_id] = "waiting"
+        await self.forward_prompt(shard, prompt, request_id, 0)
+        return None

+      self.outstanding_requests[request_id] = "processing"
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      await self.process_inference_result(shard, result, request_id)
+    except Exception as e:
+      if request_id in self.outstanding_requests:
+        self.outstanding_requests.pop(request_id)
+      if context.request_span:
+        context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
+      raise

   async def enqueue_example(
     self,
@@ -350,33 +422,39 @@ class Node:
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-  ) -> None:
-    shard = self.get_current_shard(base_shard)
-    start_time = time.perf_counter_ns()
-    await self._process_tensor(shard, tensor, request_id)
-    end_time = time.perf_counter_ns()
-    elapsed_time_ns = end_time - start_time
-    if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
-
-  async def _process_tensor(
-    self,
-    base_shard: Shard,
-    tensor: np.ndarray,
-    request_id: Optional[str] = None,
-  ) -> None:
-    if request_id is None:
-      request_id = str(uuid.uuid4())
-    shard = self.get_current_shard(base_shard)
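+    # Reuse the trace context propagated from the sending peer, if any;
+    # otherwise create a fresh one for this request.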
+  ):
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id or str(uuid.uuid4()), sequence_number=0)
+      tracer.set_context(request_id, context)

     try:
       self.outstanding_requests[request_id] = "processing"
-      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      await self.process_inference_result(shard, result, request_id)
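+      # Run local inference under a span named for this node's partition.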
+      with tracer.start_span(
+        f"process_tensor.{self.get_partition_index()}",
+        context,
+        extra_attributes={
+          "partition_index": self.get_partition_index(),
+          "node_id": self.id,
+          "start_layer": base_shard.start_layer,
+          "end_layer": base_shard.end_layer,
+          "tensor_shape": str(tensor.shape)
+        }
+      ):
+        result = await self.inference_engine.infer_tensor(request_id, base_shard, tensor)
+        await self.process_inference_result(base_shard, result, request_id)
     except Exception as e:
-      self.outstanding_requests.pop(request_id)
-      print(f"Error processing tensor for shard {shard}: {e}")
+      if request_id in self.outstanding_requests:
+        self.outstanding_requests.pop(request_id)
+      if context and context.request_span:
+        context.request_span.set_status(Status(StatusCode.ERROR, str(e)))
+      print(f"Error processing tensor for shard {base_shard}: {e}")
       traceback.print_exc()
-
+      raise
+
   async def forward_example(
     self,
     base_shard: Shard,
@@ -405,18 +483,40 @@ class Node:
     request_id: str,
     target_index: int,
   ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    next_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
-    if target_id == self.id:
-      await self.process_prompt(next_shard, prompt, request_id)
-    else:
-      target_peer = next((p for p in self.peers if p.id() == target_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {target_index} not found")
-      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
-      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
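+    # Each hop runs under its own span so cross-node latency shows up in the trace.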
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id, sequence_number=0)
+      tracer.set_context(request_id, context)
+
+    with tracer.start_span(
+      "forward_prompt",
+      context,
+      extra_attributes={
+        "source_node": self.id,
+        "target_index": target_index,
+        "prompt": prompt
+      }
+    ) as span:
+      if DEBUG >= 1: print(f"target partition index: {target_index}")
+      target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+      next_shard = self.get_current_shard(base_shard, target_index)
+      span.set_attribute("target_node", target_id)
+
+      # Get trace context for propagation
+      trace_parent = tracer.inject_context(span)
+
+      if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+      if target_id == self.id:
+        # Update local context with trace parent
+        context.trace_parent = trace_parent
+        await self.process_prompt(next_shard, prompt, request_id)
+      else:
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {target_index} not found")
+        if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+        await target_peer.send_prompt(next_shard, prompt, request_id=request_id, sequence_number=context.sequence_number, trace_parent=trace_parent)

   async def forward_tensor(
     self,
@@ -424,19 +524,41 @@ class Node:
     tensor: np.ndarray,
     request_id: str,
     target_index: int,
-  ) -> None:
-    if DEBUG >= 1: print(f"target partition index: {target_index}")
-    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
-    next_shard = self.get_current_shard(base_shard, target_index)
-    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
-    if target_id == self.id:
-      await self.process_tensor(next_shard, tensor, request_id)
-    else:
-      target_peer = next((p for p in self.peers if p.id() == target_id), None)
-      if not target_peer:
-        raise ValueError(f"Peer for {target_index} not found")
-      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
-      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
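+    # Same per-hop span as forward_prompt, with tensor metadata attached.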
+  ):
+    context = tracer.get_context(request_id)
+    if not context:
+      context = TraceContext(request_id=request_id, sequence_number=0)
+      tracer.set_context(request_id, context)
+
+    with tracer.start_span(
+      "forward_tensor",
+      context,
+      extra_attributes={
+        "source_node": self.id,
+        "target_index": target_index,
+        "tensor_shape": str(tensor.shape)
+      }
+    ) as span:
+      target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+      next_shard = self.get_current_shard(base_shard, target_index)
+      span.set_attribute("target_node", target_id)
+
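+      # Propagate the request-level span as the parent when this node owns it,
+      # falling back to this hop's span on relay nodes.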
+      trace_parent = tracer.inject_context(context.request_span or span)
+
+      if target_id == self.id:
+        # Update local context with trace parent
+        context.trace_parent = trace_parent
+        await self.process_tensor(next_shard, tensor, request_id)
+      else:
+        target_peer = next((p for p in self.peers if p.id() == target_id), None)
+        if not target_peer:
+          raise ValueError(f"Peer for {target_index} not found")
+
+        if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+        await target_peer.send_tensor(next_shard, tensor, request_id=request_id, sequence_number=context.sequence_number, trace_parent=trace_parent)

   def get_partition_index(self, offset: int = 0):
     if not self.partitioning_strategy:
@@ -570,20 +692,34 @@ class Node:
     return self._on_opaque_status

   def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
-    if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
+    if DEBUG >= 2: print(f"[Node] Triggering token callbacks: {request_id=} {token=} {is_finished=}")
     self.on_token.trigger_all(request_id, token, is_finished)

-  async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool) -> None:
-    async def send_new_token_to_peer(peer):
-      try:
-        await asyncio.wait_for(peer.send_new_token(request_id, token, is_finished), timeout=15.0)
-      except asyncio.TimeoutError:
-        print(f"Timeout broadcasting new token to {peer.id()}")
-      except Exception as e:
-        print(f"Error broadcasting new token to {peer.id()}: {e}")
-        traceback.print_exc()
-
-    await asyncio.gather(*[send_new_token_to_peer(peer) for peer in self.peers], return_exceptions=True)
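+  # Tokens are sent with the request's sequence number and trace parent so
+  # receiving peers can attach them to the correct trace.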
+  async def broadcast_new_token(self, request_id: str, token: int, is_finished: bool):
+    """Broadcast a new token to all peers."""
+    context = tracer.get_context(request_id)
+    trace_parent = ""
+    if context:
+      # Handle token in tracer for grouping
+      tracer.handle_token(context, token, is_finished)
+      # Get current trace context for propagation
+      if context.current_span:
+        trace_parent = tracer.inject_context(context.current_span)
+
+    tasks = [
+      peer.send_new_token(
+        request_id,
+        token,
+        is_finished,
+        context.sequence_number if context else 0,
+        trace_parent
+      )
+      for peer in self.peers
+    ]
+    # return_exceptions=True so one failed peer does not abort the broadcast
+    await asyncio.gather(*tasks, return_exceptions=True)

   async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
     if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")