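Summary: this patch threads an optional image string (`image_str`) through `StandardNode`'s prompt path. `process_prompt` and `_process_prompt` gain the new parameter, it is included in both broadcast status payloads, and it is passed along to `inference_engine.infer_prompt`, `forward_to_next_shard`, and the peer's `send_prompt`, so an image attached to a request survives multi-node hops.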
@@ -69,7 +69,7 @@ class StandardNode(Node):
     await self.discovery.stop()
     await self.server.stop()
 
-  async def process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     shard = self.get_current_shard(base_shard)
     asyncio.create_task(
       self.broadcast_opaque_status(
@@ -82,6 +82,7 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
+          "image_str": image_str,
           "inference_state": inference_state,
           "request_id": request_id,
         }
@@ -89,7 +90,7 @@ class StandardNode(Node):
       )
     )
     start_time = time.perf_counter_ns()
-    resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
+    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
     end_time = time.perf_counter_ns()
     elapsed_time_ns = end_time - start_time
     asyncio.create_task(
@@ -103,6 +104,7 @@ class StandardNode(Node):
           "base_shard": base_shard.to_dict(),
           "shard": shard.to_dict(),
           "prompt": prompt,
+          "image_str": image_str,
           "inference_state": inference_state,
           "request_id": request_id,
           "elapsed_time_ns": elapsed_time_ns,
@@ -113,20 +115,20 @@ class StandardNode(Node):
       )
     )
     return resp
 
-  async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
+  async def _process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
     if request_id is None:
       request_id = str(uuid.uuid4())
     if request_id not in self.buffered_token_output:
       self.buffered_token_output[request_id] = ([], False)
     shard = self.get_current_shard(base_shard)
 
-    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+    if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {image_str=}")
     if shard.start_layer != 0:
-      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
-      await self.forward_to_next_shard(shard, prompt, request_id)
+      if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=} {image_str=}")
+      await self.forward_to_next_shard(shard, prompt, request_id, image_str)
       return
 
-    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state=inference_state)
+    result, inference_state, is_finished = await self.inference_engine.infer_prompt(request_id, shard, prompt, image_str, inference_state=inference_state)
     is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -234,6 +236,7 @@ class StandardNode(Node):
     base_shard: Shard,
     tensor_or_prompt: Union[np.ndarray, str],
     request_id: str,
+    image_str: Optional[str] = None,
    inference_state: Optional[str] = None,
   ) -> None:
     if not self.partitioning_strategy:
@@ -255,7 +258,7 @@ class StandardNode(Node):
       if isinstance(tensor_or_prompt, np.ndarray):
         await self.process_tensor(shard, tensor_or_prompt, request_id, inference_state=inference_state)
       else:
-        await self.process_prompt(shard, tensor_or_prompt, request_id, inference_state=inference_state)
+        await self.process_prompt(shard, tensor_or_prompt, image_str, request_id, inference_state=inference_state)
       return
 
     target_peer = next((p for p in self.peers if p.id() == next_partition.node_id), None)
@@ -267,7 +270,7 @@ class StandardNode(Node):
       if isinstance(tensor_or_prompt, np.ndarray):
         await target_peer.send_tensor(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
       else:
-        await target_peer.send_prompt(next_shard, tensor_or_prompt, request_id=request_id, inference_state=inference_state)
+        await target_peer.send_prompt(next_shard, tensor_or_prompt, image_str=image_str, request_id=request_id, inference_state=inference_state)
 
   def get_current_shard(self, base_shard: Shard) -> Shard:
     partitions = self.partitioning_strategy.partition(self.topology)
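Usage sketch (not part of the patch): once `image_str` is threaded through, a caller can attach an image to a prompt and it follows the request across shards. Everything below is illustrative — the node handle, the model id, the layer range, and the assumption that the engine accepts a base64-encoded image are hypothetical; only the `process_prompt(..., image_str=...)` call shape comes from the change above.

    import asyncio
    import base64

    from exo.inference.shard import Shard

    async def ask_about_image(node, image_path: str) -> None:
      # Base64-encode the image file; the patch forwards image_str unmodified,
      # so the exact encoding contract is whatever the inference engine expects.
      with open(image_path, "rb") as f:
        image_str = base64.b64encode(f.read()).decode("utf-8")
      # Hypothetical vision shard covering the whole model on one node.
      shard = Shard(model_id="llava-1.5-7b-hf", start_layer=0, end_layer=31, n_layers=32)
      await node.process_prompt(shard, "What is in this image?", image_str=image_str)

    # e.g. asyncio.run(ask_about_image(node, "cat.png")) once a node is running.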