@@ -1,7 +1,7 @@
 from pathlib import Path
 import json
 import os
-from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
+from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16, sample_logits
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
@@ -65,37 +65,33 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)

-  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
-    await self.ensure_shard(shard)
-    start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
+  async def sample(self, x: np.ndarray):
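+    # Sample the next token from the logits at the final position; the actual sampling runs on the worker thread.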
+    logits = x[:, -1, :]
+    def sample_wrapper():
+      return sample_logits(Tensor(logits).flatten(), TEMPERATURE, 0, 0.8, 0.0, 0.0).realize()
+    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
+    return out.numpy()

-    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
+  async def encode(self, shard: Shard, prompt: str):
+    await self.ensure_shard(shard)
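+    # Tokenization runs on the single-worker executor rather than on the event loop.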
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    return tokens
+
+  async def decode(self, shard: Shard, tokens):
+    await self.ensure_shard(shard)
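+    # Detokenization is offloaded to the same executor as encode.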
+    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
+    return tokens

-    if h.shape == (1,):
-      start_pos += len(toks)
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      n_captured_toks = len(toks)
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
+  async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> np.ndarray:
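+    # Tokenize the prompt, then run inference on the resulting token ids.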
+    tokens = await self.encode(shard, prompt)
+    output_data = await self.infer_tensor(request_id, shard, tokens, inference_state)
+    return output_data

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[str] = None) -> tuple[np.ndarray, str, bool]:
     await self.ensure_shard(shard)
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
-    n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
-
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
-
-    if h.shape == (1,):
-      start_pos += n_captured_toks
-      start_pos += 1
-      n_captured_toks = 0
-      return np.array([[h.item()]]), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), h.item() == self.tokenizer.eos_token_id
-    else:
-      return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
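+    # The forward pass runs in the executor; start_pos still comes from the JSON-encoded inference_state.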
+    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos)
+    return output_data.numpy()

   async def ensure_shard(self, shard: Shard):
     if self.shard == shard: