@@ -12,6 +12,10 @@ from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
+import threading
+from functools import partial
 
 Tensor.no_grad = True
 # default settings
@@ -52,6 +56,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   def __init__(self, shard_downloader: ShardDownloader):
     self.shard = None
     self.shard_downloader = shard_downloader
+    self.model_lock = threading.Lock()
+    self.executor = ThreadPoolExecutor(max_workers=1)
 
   async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, image_str: Optional[str] = None, inference_state: Optional[str] = None) -> (np.ndarray, str, bool):
     await self.ensure_shard(shard)
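
The single-worker ThreadPoolExecutor added above is what keeps tinygrad's synchronous forward pass from blocking the asyncio event loop: the call is handed to a worker thread, and max_workers=1 guarantees at most one model invocation runs at a time. A minimal sketch of that pattern, with blocking_forward as a hypothetical stand-in for the model call:

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)  # one worker, so submitted calls are serialized

def blocking_forward(x):
  # hypothetical stand-in for a synchronous, compute-bound model call
  return x * 2

async def main():
  loop = asyncio.get_running_loop()
  # the event loop keeps servicing other coroutines while the worker runs
  result = await loop.run_in_executor(executor, blocking_forward, 21)
  print(result)  # 42

asyncio.run(main())
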
@@ -59,7 +65,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = self.tokenizer.encode(prompt)
-    h = self.model(Tensor([toks]), start_pos, TEMPERATURE).realize()
+
+    h = await self._run_inference(Tensor([toks]), start_pos)
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -75,7 +82,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = self.model(Tensor(input_data), start_pos, TEMPERATURE).realize()
+    h = await self._run_inference(Tensor(input_data), start_pos)
 
     if h.shape == (1,):
       start_pos += n_captured_toks
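
Both infer_prompt and infer_tensor now funnel through the _run_inference helper introduced in the next hunk. One detail worth noting: loop.run_in_executor forwards positional arguments only, which is presumably why build_transformer's model_size= keyword becomes a positional argument below, and why functools.partial is imported (it is the standard way to bind keyword arguments for an executor call, though this patch ends up not using it). A sketch, with build as a hypothetical stand-in:

import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor

def build(path, shard, model_size="8B"):
  # hypothetical stand-in for a builder that takes a keyword argument
  return (path, shard, model_size)

async def main():
  loop = asyncio.get_running_loop()
  ex = ThreadPoolExecutor(max_workers=1)
  a = await loop.run_in_executor(ex, build, "p", "s", "70B")  # positional, as in the patch
  b = await loop.run_in_executor(ex, partial(build, "p", "s", model_size="70B"))  # keyword via partial
  assert a == b

asyncio.run(main())
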
@@ -85,11 +92,31 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     else:
       return h.numpy(), json.dumps({"start_pos": start_pos, "n_captured_toks": n_captured_toks}), False
 
+  async def _run_inference(self, input_tensor, start_pos):
+    with self.model_lock:
+      return await asyncio.get_event_loop().run_in_executor(
+        self.executor,
+        self.model,
+        input_tensor,
+        start_pos,
+        TEMPERATURE
+      )
+
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
 
     model_path = await self.shard_downloader.ensure_shard(shard)
-    self.model = build_transformer(model_path, shard, model_size="8B" if "8b" in shard.model_id.lower() else "70B")
-    self.tokenizer = await resolve_tokenizer(str((model_path if model_path.is_dir() else model_path.parent)))
-    self.shard = shard
+
+    with self.model_lock:
+      if self.shard != shard:
+        self.model = await asyncio.get_event_loop().run_in_executor(
+          self.executor,
+          build_transformer,
+          model_path,
+          shard,
+          "8B" if "8b" in shard.model_id.lower() else "70B"
+        )
+        tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
+        self.tokenizer = await resolve_tokenizer(tokenizer_path)
+        self.shard = shard
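
One caveat with _run_inference as written: self.model_lock is a threading.Lock acquired on the event-loop thread and held across the await. If a second coroutine reaches the with statement while the first inference is still in flight, its blocking acquire() stalls the entire event loop, even though the single-worker executor already serializes the model calls themselves. An asyncio.Lock suspends the waiting coroutine instead; a minimal sketch of that alternative, mirroring the structure above but swapping in asyncio.Lock:

import asyncio
from concurrent.futures import ThreadPoolExecutor

class Engine:
  def __init__(self):
    self.model_lock = asyncio.Lock()  # cooperative: waiters suspend, the loop keeps running
    self.executor = ThreadPoolExecutor(max_workers=1)

  async def _run_inference(self, fn, *args):
    async with self.model_lock:
      return await asyncio.get_running_loop().run_in_executor(self.executor, fn, *args)

async def demo():
  engine = Engine()
  print(await engine._run_inference(lambda x: x + 1, 41))  # 42

asyncio.run(demo())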