@@ -122,7 +122,6 @@ class StandardNode(Node):
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
- self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
@@ -211,13 +210,16 @@ class StandardNode(Node):
):
shard = self.get_current_shard(base_shard)
if shard.is_first_layer():
- resp = await self.process_example(shard, example, target, length, train, request_id)
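+ # This node holds the first layer: run the example here and return the loss.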
+ loss = await self.process_example(shard, example, target, length, train, request_id)
+ return loss
else:
if request_id is None:
request_id = str(uuid.uuid4())
self.outstanding_requests[request_id] = "waiting"
- resp = await self.forward_example(shard, example, target, length, train, request_id, 0)
- return resp
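+ # Not the first layer: forward the example to the node holding partition 0 and await the loss.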
+ loss = await self.forward_example(shard, example, target, length, train, request_id, 0)
+ return loss
 
async def coordinate_save(
self,
@@ -279,7 +281,6 @@ class StandardNode(Node):
"shard": shard.to_dict(),
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
- "result_size": resp.size if resp is not None else 0,
}),
)
)
@@ -308,11 +309,16 @@ class StandardNode(Node):
self.outstanding_requests[request_id] = "preprocessing"
step = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
- backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
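+ # Downstream shards return the final loss together with the gradient at this shard's output boundary.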
+ loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
self.outstanding_requests[request_id] = "training"
- loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
+ partial_loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
self.outstanding_requests.pop(request_id)
- return loss.reshape(1, -1) if shard.is_first_layer() else grad
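+ # The first shard only surfaces the loss; intermediate shards also pass their input gradient upstream.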
+ if shard.is_first_layer():
+ return loss
+ else:
+ return loss, grad
else:
if shard.is_last_layer():
self.outstanding_requests[request_id] = "evaluating"
@@ -323,7 +329,7 @@ class StandardNode(Node):
self.outstanding_requests[request_id] = "waiting"
loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
self.outstanding_requests.pop(request_id)
- return loss.reshape(1, -1)
+ return loss
except Exception as e:
self.outstanding_requests.pop(request_id)
print(f"Error processing example for shard {shard}: {e}")
@@ -413,25 +419,9 @@ class StandardNode(Node):
if not target_peer:
raise ValueError(f"peer for {target_index} not found")
if DEBUG >= 1: print(f"sending example to {target_peer.id()}: {step} => {target} ({length})")
- ret = await target_peer.send_example(target_shard, step, target, length, request_id=request_id, train=train)
- return ret
-
- async def forward_loss(
- self,
- base_shard: Shard,
- loss: np.ndarray,
- request_id: str,
- target_index: int,
- ) -> None:
- if DEBUG >= 1: print(f"target partition index: {target_index}")
- target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
- target_shard = self.get_current_shard(base_shard, target_index)
- if DEBUG >= 2: print(f"computed target from: {base_shard} {target_index}, {self.topology}. target shard: {target_shard}")
- target_peer = next((p for p in self.peers if p.id() == target_id), None)
- if not target_peer:
- raise ValueError(f"peer for {target_index} not found")
- if DEBUG >= 1: print(f"sending tensor to {target_peer.id()}: {loss}")
- await target_peer.send_loss(target_shard, step, target, length, request_id=request_id)
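+ # The reply to send_example carries the loss (and gradient) computed by the downstream shards.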
+ resp = await target_peer.send_example(target_shard, step, target, length, request_id=request_id, train=train)
+ return resp
 
async def forward_prompt(
self,