@@ -53,6 +53,7 @@ class StandardNode(Node):
     self.node_download_progress: Dict[str, RepoProgressEvent] = {}
     self.topology_inference_engines_pool: List[List[str]] = []
     self.shard_downloader = shard_downloader
+    self.outstanding_requests = {}
 
   async def start(self, wait_for_peers: int = 0) -> None:
     await self.server.start()
@@ -119,8 +120,10 @@ class StandardNode(Node):
       token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
+      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
+      self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-      is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+      asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
       forward = token.reshape(1, -1)
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
       asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
@@ -129,7 +132,9 @@ class StandardNode(Node):
 
     if is_finished:
       self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+      self.outstanding_requests.pop(request_id)
     else:
+      self.outstanding_requests[request_id] = "waiting"
       asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
 
     return np.array(self.buffered_token_output[request_id][0])
@@ -185,11 +190,13 @@ class StandardNode(Node):
     if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
     if not shard.is_first_layer():
       if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+      self.outstanding_requests[request_id] = "waiting"
       resp = await self.forward_prompt(shard, prompt, request_id, 0)
       return None
     else:
+      self.outstanding_requests[request_id] = "processing"
       result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
-      ret = await self.process_inference_result(shard, result, request_id)
+      ret = await self.process_inference_result(shard, result, request_id)
       return result
 
   async def enqueue_example(
@@ -207,6 +214,7 @@ class StandardNode(Node):
     else:
       if request_id is None:
         request_id = str(uuid.uuid4())
+      self.outstanding_requests[request_id] = "waiting"
       resp = await self.forward_example(shard, example, target, length, train, request_id, 0)
     return resp
 
@@ -274,20 +282,30 @@ class StandardNode(Node):
       target = target.astype(int)
       if train:
         if shard.is_last_layer():
+          self.outstanding_requests[request_id] = "training"
           loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
         else:
+          self.outstanding_requests[request_id] = "preprocessing"
           step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          self.outstanding_requests[request_id] = "waiting"
           backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+          self.outstanding_requests[request_id] = "training"
           loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
+        self.outstanding_requests.pop(request_id)
         return loss.reshape(example.shape[0], -1) if shard.is_first_layer() else grad
       else:
         if shard.is_last_layer():
+          self.outstanding_requests[request_id] = "evaluating"
           loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
         else:
+          self.outstanding_requests[request_id] = "preprocessing"
           step = await self.inference_engine.infer_tensor(request_id, shard, example)
+          self.outstanding_requests[request_id] = "waiting"
           loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+        self.outstanding_requests.pop(request_id)
         return loss.reshape(example.shape[0], -1)
     except Exception as e:
+      self.outstanding_requests.pop(request_id)
       print(f"Error processing example for shard {shard}: {e}")
       traceback.print_exc()
       return None
@@ -347,10 +365,12 @@ class StandardNode(Node):
 
     if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
     try:
+      self.outstanding_requests[request_id] = "processing"
       result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
       ret = await self.process_inference_result(shard, result, request_id)
       return ret
     except Exception as e:
+      self.outstanding_requests.pop(request_id)
       print(f"Error processing tensor for shard {shard}: {e}")
       traceback.print_exc()
       return None
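
Taken together, the hunks make `outstanding_requests` a small per-request state machine: each request is marked "waiting", "processing", "preprocessing", "training", or "evaluating" as it moves through the node, and popped when it finishes or errors. A minimal standalone sketch of that lifecycle, for reference only; the `RequestTracker` class is illustrative and not part of this diff:

# Minimal sketch of the state tracking this diff introduces; the state
# strings mirror the ones used above, everything else is hypothetical.
from typing import Dict

class RequestTracker:
  def __init__(self):
    # request_id -> "waiting" | "processing" | "preprocessing"
    #             | "training" | "evaluating"
    self.outstanding_requests: Dict[str, str] = {}

  def mark(self, request_id: str, state: str) -> None:
    self.outstanding_requests[request_id] = state

  def finish(self, request_id: str) -> None:
    # The diff calls dict.pop(request_id) with no default; pop with a
    # default (as here) would avoid a KeyError if a request errors
    # before it was ever marked, or is finished twice.
    self.outstanding_requests.pop(request_id, None)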