@@ -111,7 +111,6 @@ class StandardNode(Node):
     shard,
     result: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ):
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
@@ -123,7 +122,6 @@ class StandardNode(Node):

     if shard.is_last_layer():
       result = await self.inference_engine.sample(result)
-      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id]) + 1})

     await self.inference_engine.ensure_shard(shard)
     is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
@@ -139,7 +137,7 @@ class StandardNode(Node):
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
     else:
-      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id, inference_state=inference_state))
+      asyncio.create_task(self.forward_to_next_shard(shard, result, request_id))

     return np.array(self.buffered_token_output[request_id][0]) if len(self.buffered_token_output[request_id][0]) > 0 else None

@@ -148,7 +146,6 @@ class StandardNode(Node):
     base_shard: Shard,
     prompt: str,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -161,13 +158,12 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "inference_state": inference_state,
           "request_id": request_id,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -180,7 +176,6 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
-          "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
           "result_size": resp.size if resp is not None else 0,
@@ -189,7 +184,7 @@ class StandardNode(Node):
     )
     return resp

-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     shard = self.get_current_shard(base_shard)
@@ -197,11 +192,11 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if shard.start_layer != 0:
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id, inference_state=inference_state)
+      await self.forward_to_next_shard(shard, prompt, request_id)
       return None
     else:
-      result = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
-      ret = await self.process_result(shard, result, request_id, inference_state=inference_state)
+      result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+      ret = await self.process_result(shard, result, request_id)
       return result

   async def process_tensor(
@@ -209,7 +204,6 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
@@ -224,12 +218,11 @@ class StandardNode(Node):
           "tensor_size": tensor.size,
           "tensor_shape": tensor.shape,
           "request_id": request_id,
-          "inference_state": inference_state,
         }),
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
+    resp = await self._process_tensor(shard, tensor, request_id)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -254,7 +247,6 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor: np.ndarray,
     request_id: Optional[str] = None,
-    inference_state: Optional[str] = None,
   ) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
@@ -262,8 +254,8 @@ class StandardNode(Node):

     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
-      result = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state=inference_state)
-      ret = await self.process_result(shard, result, request_id, inference_state=inference_state)
+      result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+      ret = await self.process_result(shard, result, request_id)
       return ret
     except Exception as e:
       print(f"Error processing tensor for shard {shard}: {e}")
@@ -275,7 +267,6 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
-    inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
       if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
@@ -290,18 +281,18 @@ class StandardNode(Node):
     is_tensor = isinstance(tensor_or_prompt, np.ndarray)
     if target_id == self.id:
       if is_tensor:
-        await self.process_tensor(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        await self.process_tensor(next_shard, tensor_or_prompt, request_id)
       else:
-        await self.process_prompt(next_shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        await self.process_prompt(next_shard, tensor_or_prompt, request_id)
     else:
       target_peer = next((p for p in self.peers if p.id() == target_id), None)
       if not target_peer:
         raise ValueError(f"Peer for {next_partition_index} not found")
       if is_tensor:
         if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor_or_prompt}")
-        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id)
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id)

   def get_partition_index(self, offset: int = 0):
     partitions = self.partitioning_strategy.partition(self.topology)
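
The net effect of the hunks above is that process_prompt, process_tensor, process_result, and forward_to_next_shard no longer accept or thread an inference_state string between nodes; each hop now carries only the shard, the tensor-or-prompt payload, and the request_id. Below is a minimal, self-contained sketch of the simplified forwarding shape. It uses toy stand-ins (ToyNode, a dict in place of exo's Shard), not the real StandardNode internals.

# Toy sketch of the post-diff call path: no inference_state is threaded
# through, so a hop is fully described by (shard, payload, request_id).
# ToyNode and the dict "shard" are hypothetical stand-ins for illustration.
import asyncio
import uuid
from typing import Optional, Union

import numpy as np


class ToyNode:
  """Stand-in for StandardNode; only the forwarding shape is modeled."""

  async def forward_to_next_shard(
    self,
    shard: dict,  # stand-in for exo's Shard
    tensor_or_prompt: Union[np.ndarray, str],
    request_id: str,
  ) -> None:
    # The branch now only distinguishes tensor vs. prompt; there is no
    # inference_state keyword left to pass along.
    if isinstance(tensor_or_prompt, np.ndarray):
      await self.process_tensor(shard, tensor_or_prompt, request_id)
    else:
      await self.process_prompt(shard, tensor_or_prompt, request_id)

  async def process_tensor(self, shard: dict, tensor: np.ndarray, request_id: Optional[str] = None) -> None:
    print(f"[{request_id}] tensor hop on {shard}: shape={tensor.shape}")

  async def process_prompt(self, shard: dict, prompt: str, request_id: Optional[str] = None) -> None:
    print(f"[{request_id}] prompt hop on {shard}: {prompt!r}")


async def main() -> None:
  node = ToyNode()
  rid = str(uuid.uuid4())
  await node.forward_to_next_shard({"start_layer": 0}, "hello world", rid)
  await node.forward_to_next_shard({"start_layer": 1}, np.zeros((1, 4)), rid)


if __name__ == "__main__":
  asyncio.run(main())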