
add --default-temp option to change sample temperature

Alex Cheema 5 months ago
parent commit c2647764b4
2 changed files with 6 additions and 2 deletions
  1. exo/main.py (+3 −1)
  2. exo/orchestration/standard_node.py (+3 −1)
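
Usage sketch (not part of the commit): with this change applied, the sampling temperature could be set from the command line, e.g.

# Hypothetical invocation; the model name is illustrative, not from this commit.
python exo/main.py --run-model llama-3.1-8b --prompt "Who are you?" --default-temp 0.7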

exo/main.py (+3 −1)

@@ -55,6 +55,7 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
@@ -119,7 +120,8 @@ node = StandardNode(
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader
+  shard_downloader=shard_downloader,
+  default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
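
The engine-side sample() implementation is not part of this diff; as a minimal sketch of what a temp parameter conventionally does in token sampling (assumed behavior, not exo's actual code):

import numpy as np

def sample(logits: np.ndarray, temp: float = 0.0) -> int:
  # temp == 0.0 is treated as greedy decoding: always take the argmax token.
  if temp == 0.0:
    return int(np.argmax(logits))
  # Otherwise scale logits by 1/temp and draw from the softmax distribution.
  scaled = logits / temp
  probs = np.exp(scaled - scaled.max())  # shift by max for numerical stability
  probs /= probs.sum()
  return int(np.random.choice(len(probs), p=probs))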

exo/orchestration/standard_node.py (+3 −1)

@@ -27,6 +27,7 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
+    default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
     shard_downloader: Optional[HFShardDownloader] = None,
   ):
@@ -43,6 +44,7 @@ class StandardNode(Node):
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
+    self.default_sample_temperature = default_sample_temperature
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
@@ -112,7 +114,7 @@ class StandardNode(Node):
       self.buffered_token_output[request_id] = ([], False)
     is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if shard.is_last_layer() and not is_finished:
-      token = await self.inference_engine.sample(result)
+      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
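
Design note: the default of 0.0 presumably maps to greedy (argmax) decoding, so generation stays deterministic unless a user opts in via --default-temp. Higher temperatures flatten the softmax over the logits while lower ones sharpen it: for logits (2, 1, 0), temp 1.0 gives probabilities of roughly (0.67, 0.24, 0.09), while temp 0.5 sharpens them to roughly (0.87, 0.12, 0.02).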