@@ -102,11 +102,7 @@ class StandardNode(Node):
   def get_topology_inference_engines(self) -> List[List[str]]:
     return self.topology_inference_engines_pool
 
-  async def encode_prompt(self, shard: Shard, prompt):
-    toks = await self.inference_engine.encode(shard, prompt)
-    return toks
-
-  async def process_result(
+  async def process_inference_result(
     self,
     shard,
     result: np.ndarray,
@@ -114,32 +110,25 @@ class StandardNode(Node):
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
-
-    if request_id not in self.buffered_logits:
-      self.buffered_logits[request_id] = []
-
-    self.buffered_logits[request_id] += [i for i in np.reshape(result, (-1, 1, result.shape[-1]))]
-
-    if shard.is_last_layer():
-      result = await self.inference_engine.sample(result)
-
-    await self.inference_engine.ensure_shard(shard)
-    is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
-
-    asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))  # TODO: this is n^2 communication complexity
-
-    if result.size == 1:  # we got a new token out
-      self.buffered_token_output[request_id][0].append(result.item())
+    is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+    if shard.is_last_layer() and not is_finished:
+      token = await self.inference_engine.sample(result)
+      await self.inference_engine.ensure_shard(shard)
+      self.buffered_token_output[request_id][0].append(token.item())
+      if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      forward = token.reshape(1, -1)
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
-
-    if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+      asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+    else:
+      forward = result
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))
+      asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
 
-    return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None
+    return np.array(self.buffered_token_output[request_id][0])
 
   async def process_prompt(
     self,
@@ -190,13 +179,13 @@ class StandardNode(Node):
     shard = self.get_current_shard(base_shard)
 
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
-    if shard.start_layer != 0:
+    if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_result(shard, result, request_id)
+      ret = await self.process_inference_result(shard, result, request_id)
       return result
 
   async def process_tensor(
@@ -255,46 +244,57 @@ class StandardNode(Node):
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
-      ret = await self.process_result(shard, result, request_id)
+      ret = await self.process_inference_result(shard, result, request_id)
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
 
-  async def forward_to_next_shard(
+  async def forward_prompt(
     self,
     base_shard: Shard,
-    tensor_or_prompt: Union[np.ndarray, str],
+    prompt: str,
     request_id: str,
+    target_index: int,
   ) -> None:
-    if not self.partitioning_strategy:
-      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
-      return
-
-    next_partition_index = self.get_partition_index(offset = 1)
-    if DEBUG >= 1: print(f"Next partition index: {next_partition_index}")
-    if next_partition_index is not None:
-      target_id = self.partitioning_strategy.partition(self.topology)[next_partition_index].node_id
-      next_shard = self.get_current_shard(base_shard, next_partition_index)
-      if DEBUG >= 2: print(f"Computed next from: {base_shard} {next_partition_index}, {self.topology}. Next shard: {next_shard}")
-      is_tensor = isinstance(tensor_or_prompt, np.ndarray)
-      if target_id == self.id:
-        if is_tensor:
-          await self.process_tensor(next_shard, tensor_or_prompt, request_id)
-        else:
-          await self.process_prompt(next_shard, tensor_or_prompt, request_id)
-      else:
-        target_peer = next((p for p in self.peers if p.id() == target_id), None)
-        if not target_peer:
-          raise ValueError(f"Peer for {next_partition_index} not found")
-        if is_tensor:
-          if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-          await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
-        else:
-          await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_prompt(next_shard, prompt, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+      await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+
+  async def forward_tensor(
+    self,
+    base_shard: Shard,
+    tensor: np.ndarray,
+    request_id: str,
+    target_index: int,
+  ) -> None:
+    if DEBUG >= 1: print(f"target partition index: {target_index}")
+    target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+    next_shard = self.get_current_shard(base_shard, target_index)
+    if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+    if target_id == self.id:
+      await self.process_tensor(next_shard, tensor, request_id)
+    else:
+      target_peer = next((p for p in self.peers if p.id() == target_id), None)
+      if not target_peer:
+        raise ValueError(f"Peer for {target_index} not found")
+      if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+      await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
 
   def get_partition_index(self, offset: int = 0):
+    if not self.partitioning_strategy:
+      if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+      return None
     partitions = self.partitioning_strategy.partition(self.topology)
     current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
     if current_partition_index is None: