瀏覽代碼

simplify tinygrad non blocking

Alex Cheema 8 月之前
父節點
當前提交
4ec613d4e8
共有 1 個文件被更改,包括 10 次插入14 次删除
  1. 10 14
      exo/inference/tinygrad/inference.py

+ 10 - 14
exo/inference/tinygrad/inference.py

@@ -56,7 +56,6 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
-    self.model_lock = threading.Lock()
     self.executor = ThreadPoolExecutor(max_workers=1)
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
@@ -64,9 +63,10 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    toks = self.tokenizer.encode(prompt)
+    toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
+    input_tensor = Tensor([toks])
 
-    h = await self._run_inference(Tensor([toks]), start_pos)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -82,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = await self._run_inference(Tensor(input_data), start_pos)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
 
     if h.shape == (1,):
       start_pos += n_captured_toks
@@ -92,19 +92,15 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     else:
       return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
-  async def _run_inference(self, input_tensor, start_pos):
-    with self.model_lock:
-      return await asyncio.get_event_loop().run_in_executor(self.executor, self.model, input_tensor, start_pos, TEMPERATURE)
-
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
 
-    with self.model_lock:
-      if self.shard != shard:
-        self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
-        tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
-        self.tokenizer = await resolve_tokenizer(tokenizer_path)
-        self.shard = shard
+    if self.shard != shard:
+      self.model = await asyncio.get_event_loop().run_in_executor(self.executor, build_transformer, model_path, shard, "8B" if "8b" in shard.model_id.lower() else "70B")
+
+      tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
+      self.tokenizer = await resolve_tokenizer(tokenizer_path)
+      self.shard = shard