Merge pull request #512 from exo-explore/sampletemp

add --default-temp option to change sample temperature
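
This threads a new --default-temp CLI flag (default 0.0, i.e. greedy decoding) through StandardNode as default_sample_temperature and passes it to InferenceEngine.sample as temp, so the sampling temperature can be set at launch time.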
Alex Cheema · 7 months ago · parent commit 0211cb0b15
3 changed files with 7 additions and 3 deletions:
  1. exo/inference/dummy_inference_engine.py (+1 -1)
  2. exo/main.py (+3 -1)
  3. exo/orchestration/standard_node.py (+3 -1)

exo/inference/dummy_inference_engine.py (+1 -1)

@@ -18,7 +18,7 @@ class DummyInferenceEngine(InferenceEngine):
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     return np.array(self.tokenizer.encode(prompt))
   
-  async def sample(self, x: np.ndarray) -> np.ndarray:
+  async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
     if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
     return x
 

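The dummy engine accepts the widened signature but still ignores temp and top_p. For context, a minimal sketch of how a real engine might apply these parameters (illustrative only, not the repository's actual MLX/tinygrad samplers; sample_with_temperature is a hypothetical helper):

import numpy as np

def sample_with_temperature(logits: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
  # temp == 0.0 keeps the previous behaviour: greedy argmax decoding.
  if temp == 0.0:
    return np.array([int(np.argmax(logits))])
  # Scale logits by temperature and convert to a probability distribution.
  scaled = logits / temp
  probs = np.exp(scaled - np.max(scaled))
  probs /= probs.sum()
  # Nucleus (top-p) filtering: keep the smallest set of tokens whose
  # cumulative probability reaches top_p, then renormalise.
  if top_p < 1.0:
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    probs = filtered / filtered.sum()
  return np.array([int(np.random.choice(len(probs), p=probs))])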
exo/main.py (+3 -1)

@@ -55,6 +55,7 @@ parser.add_argument("--inference-engine", type=str, default=None, help="Inferenc
 parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
 parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
 parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
+parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
 parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
 parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
 args = parser.parse_args()
@@ -119,7 +120,8 @@ node = StandardNode(
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader
+  shard_downloader=shard_downloader,
+  default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

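With this flag in place, a node can be started with a non-zero sampling temperature, e.g. exo --run-model <model> --default-temp 0.7 (assuming the usual exo entrypoint; the model name is a placeholder). Omitting the flag keeps the previous greedy behaviour, since the default is 0.0.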
exo/orchestration/standard_node.py (+3 -1)

@@ -27,6 +27,7 @@ class StandardNode(Node):
     discovery: Discovery,
     partitioning_strategy: PartitioningStrategy = None,
     max_generate_tokens: int = 1024,
+    default_sample_temperature: float = 0.0,
     topology_viz: Optional[TopologyViz] = None,
     shard_downloader: Optional[HFShardDownloader] = None,
   ):
@@ -43,6 +44,7 @@ class StandardNode(Node):
     self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
     self.max_generate_tokens = max_generate_tokens
     self.topology_viz = topology_viz
+    self.default_sample_temperature = default_sample_temperature
     self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
     self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
     self._on_opaque_status.register("node_status").on_next(self.on_node_status)
@@ -112,7 +114,7 @@ class StandardNode(Node):
       self.buffered_token_output[request_id] = ([], False)
     is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
     if shard.is_last_layer() and not is_finished:
-      token = await self.inference_engine.sample(result)
+      token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
       await self.inference_engine.ensure_shard(shard)
       self.buffered_token_output[request_id][0].append(token.item())
       if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
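
Note that only temp is forwarded here; top_p is not threaded through, so nucleus sampling keeps whatever default the engine defines (1.0 in the dummy engine's signature). The temperature also only takes effect on the node holding the last layer of the model, since the sample call sits behind shard.is_last_layer().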