Parcourir la source

I think this is more faithful to how it was originally done

Nel Nibcord il y a 8 mois
Parent
commit
aefc0d7c51
1 fichiers modifiés avec 1 ajouts et 3 suppressions
  1. 1 3
      exo/orchestration/standard_node.py

+ 1 - 3
exo/orchestration/standard_node.py

@@ -122,10 +122,9 @@ class StandardNode(Node):
     for i in np.reshape(result, (-1, 1, result.shape[-1])):
       self.buffered_logits[request_id][0].append(i)
 
-    inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0])})
-
     if shard.is_last_layer():
       result = await self.inference_engine.sample(result)
+      inference_state = json.dumps({"start_pos": len(self.buffered_logits[request_id][0]) + 1})
     
     await self.inference_engine.ensure_shard(shard)
     is_finished = result.size == 1 and result.item() == self.inference_engine.tokenizer.eos_token_id or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
@@ -134,7 +133,6 @@ class StandardNode(Node):
 
     if result.size == 1:  # we got a new token out
       self.buffered_token_output[request_id][0].append(result.item())
-      inference_state = json.dumps({"start_pos": json.loads(inference_state or "{}").get("start_pos", 0) + 1})
       self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
     
     if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")