|
@@ -56,7 +56,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
def __init__(self, shard_downloader: ShardDownloader):
|
|
|
self.shard = None
|
|
|
self.shard_downloader = shard_downloader
|
|
|
- self.model_lock = threading.Lock()
|
|
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
|
|
|
|
|
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
|
|
@@ -64,9 +63,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
|
|
|
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
|
|
|
|
|
|
- toks = self.tokenizer.encode(prompt)
|
|
|
+ toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
|
|
|
+ input_tensor = Tensor([toks])
|
|
|
|
|
|
- h = await self._run_inference(Tensor([toks]), start_pos)
|
|
|
+ h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
|
|
|
|
|
|
if h.shape == (1,):
|
|
|
start_pos += len(toks)
|
|
@@ -82,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
|
|
|
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
|
|
|
|
|
|
- h = await self._run_inference(Tensor(input_data), start_pos)
|
|
|
+ h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
|
|
|
|
|
|
if h.shape == (1,):
|
|
|
start_pos += n_captured_toks
|
|
@@ -92,19 +92,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
else:
|
|
|
return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
|
|
|
|
|
|
- async def _run_inference(self, input_tensor, start_pos):
|
|
|
- with self.model_lock:
|
|
|
- return await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
|
|
|
-
|
|
|
async def ensure_shard(self, shard: Shard):
|
|
|
if self.shard == shard:
|
|
|
return
|
|
|
|
|
|
model_path = await self.shard_downloader.ensure_shard(shard)
|
|
|
|
|
|
- with self.model_lock:
|
|
|
- if self.shard != shard:
|
|
|
- self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
|
|
|
- tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
|
|
|
- self.tokenizer = await resolve_tokenizer(tokenizer_path)
|
|
|
- self.shard = shard
|
|
|
+ if self.shard != shard:
|
|
|
+ self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
|
|
|
+
|
|
|
+ tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
|
|
|
+ self.tokenizer = await resolve_tokenizer(tokenizer_path)
|
|
|
+ self.shard = shard
|