
experiment with tinygrad on its own thread, so it doesn't block the event loop

Alex Cheema · 8 months ago
Commit 2950373d36
1 file changed, 32 insertions and 5 deletions

exo/inference/tinygrad/inference.py  +32 −5

@@ -12,6 +12,10 @@ from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
+import threading
+from functools import partial

 Tensor.no_grad = True
 # default settings
@@ -52,6 +56,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
+    self.model_lock = threading.Lock()
+    self.executor = ThreadPoolExecutor(max_workers=1)

   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
@@ -59,7 +65,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)

     toks = self.tokenizer.encode(prompt)
-    h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
+
+    h = await self._run_inference(Tensor([toks]), start_pos)

     if h.shape == (1,):
       start_pos += len(toks)
@@ -75,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)

-    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
+    h = await self._run_inference(Tensor(input_data), start_pos)

     if h.shape == (1,):
       start_pos += n_captured_toks
@@ -85,11 +92,31 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     else:
       return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False

+  async def _run_inference(self, input_tensor, start_pos):
+    with self.model_lock:
+      return await asyncio.get_event_loop().run_in_executor(
+        self.executor,
+        self.model,
+        input_tensor,
+        start_pos,
+        TEMPERATURE
+      )
+
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return

     model_path = await self.shard_downloader.ensure_shard(shard)
-    self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
-    self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
-    self.shard = shard
+
+    with self.model_lock:
+      if self.shard != shard:
+        self.model = await asyncio.get_event_loop().run_in_executor(
+          self.executor,
+          build_transformer,
+          model_path,
+          shard,
+          "8B" if "8b" in shard.model_id.lower() else "70B"
+        )
+        tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
+        self.tokenizer = await resolve_tokenizer(tokenizer_path)
+        self.shard = shard
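
The core pattern in this commit is handing tinygrad's synchronous forward pass to a dedicated single-worker ThreadPoolExecutor via run_in_executor, so the coroutine only awaits the result and the asyncio event loop stays responsive. Below is a minimal, self-contained sketch of that pattern, not the exo implementation: Engine, blocking_model_call, and heartbeat are hypothetical stand-ins chosen to make the example runnable on its own.

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for the blocking tinygrad forward pass; any
# synchronous, CPU/GPU-bound callable that would otherwise stall the
# event loop plays the same role.
def blocking_model_call(tokens, start_pos):
    time.sleep(0.5)  # simulate a slow, synchronous model step
    return f"logits(len={len(tokens)}, start_pos={start_pos})"

class Engine:
    def __init__(self):
        # A single worker serializes model calls, mirroring the commit's
        # max_workers=1 executor, so the model is never entered concurrently.
        self.executor = ThreadPoolExecutor(max_workers=1)

    async def infer(self, tokens, start_pos):
        loop = asyncio.get_running_loop()
        # The blocking call runs on the executor thread; this coroutine only
        # awaits its result, so other tasks keep making progress meanwhile.
        return await loop.run_in_executor(
            self.executor, blocking_model_call, tokens, start_pos
        )

async def heartbeat():
    # Keeps printing while inference runs, demonstrating the loop isn't blocked.
    for _ in range(5):
        print("event loop alive")
        await asyncio.sleep(0.2)

async def main():
    engine = Engine()
    result, _ = await asyncio.gather(engine.infer([1, 2, 3], 0), heartbeat())
    print(result)

if __name__ == "__main__":
    asyncio.run(main())
```

In the sketch the single worker alone is what serializes model calls; the commit additionally re-checks `self.shard != shard` under `model_lock` inside `ensure_shard`, so a shard that was loaded while the lock was awaited is not rebuilt a second time.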