|
@@ -122,10 +122,9 @@ class StandardNode(Node):
|
|
for i in np.reshape(result, (-1, 1, result.shape[-1])):
|
|
for i in np.reshape(result, (-1, 1, result.shape[-1])):
|
|
self.buffered_logits[request_id][0].append(i)
|
|
self.buffered_logits[request_id][0].append(i)
|
|
|
|
|
|
- inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0])})
|
|
|
|
-
|
|
|
|
if shard.is_last_layer():
|
|
if shard.is_last_layer():
|
|
result = await self.inference_engine.sample(result)
|
|
result = await self.inference_engine.sample(result)
|
|
|
|
+ inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0]) + 1})
|
|
|
|
|
|
await self.inference_engine.ensure_shard(shard)
|
|
await self.inference_engine.ensure_shard(shard)
|
|
is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
|
is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
|
@@ -134,7 +133,6 @@ class StandardNode(Node):
|
|
|
|
|
|
if result.size == 1: # we got a new token out
|
|
if result.size == 1: # we got a new token out
|
|
self.buffered_token_output[request_id][0].append(result.item())
|
|
self.buffered_token_output[request_id][0].append(result.item())
|
|
- inference_state = json.dumps({"start_pos": json.loads(inference_state or "{}").get("start_pos", 0) + 1})
|
|
|
|
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
|
|
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
|
|
|
|
|
|
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|
|
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|