@@ -27,6 +27,7 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
+    default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
     shard_downloader: Optional[HFShardDownloader] = None,
   ):
@@ -43,6 +44,7 @@ class StandardNode(Node):
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
+    self.default_sample_temperature = default_sample_temperature
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
@@ -112,7 +114,7 @@ class StandardNode(Node):
       self.buffered_token_output[request_id] = ([], False)
     is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if shard.is_last_layer() and not is_finished:
-      token = await self.inference_engine.sample(result)
+      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
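
For context, below is a minimal sketch of how an inference engine's sample() might interpret a temp keyword, assuming the common convention that temp == 0.0 means greedy argmax decoding; the function is illustrative only, not the engine's actual implementation:

    import numpy as np

    def sample(logits: np.ndarray, temp: float = 0.0) -> int:
      # temp == 0.0: deterministic greedy decoding, matching the new
      # default_sample_temperature default in the diff above (assumption).
      if temp == 0.0:
        return int(np.argmax(logits))
      # temp > 0.0: softmax over temperature-scaled logits, then sample.
      scaled = logits / temp
      probs = np.exp(scaled - np.max(scaled))  # subtract max for numerical stability
      probs /= probs.sum()
      return int(np.random.choice(len(probs), p=probs))

Under that convention, a default of 0.0 keeps generation deterministic unless a caller explicitly opts into stochastic sampling.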