@@ -5,7 +5,7 @@ from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggin
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
-from tinygrad import Tensor, nn, Context
+from tinygrad import Tensor, nn, Context, TinyJit
 from exo.inference.inference_engine import InferenceEngine
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
@@ -15,7 +15,8 @@ from .stateful_model import StatefulModel
 from .losses import length_masked_ce_loss
 import asyncio

-Tensor.no_grad = True
+Tensor.no_grad = False
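+# gradients must stay enabled so the new train() path below can call backward()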
 # default settings
 TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
 TOP_K = 25
@@ -63,6 +64,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     self.shard = None
     self.shard_downloader = shard_downloader
     self.executor = ThreadPoolExecutor(max_workers=1)
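+    # per-engine session state, filled lazily by ensure_session: loss fn, optimizer, jitted step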
+    self.session = {}

   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
@@ -82,11 +85,40 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):

   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    #print(f"infer_tensor in <- {input_data}")
     output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    #print(f"infer_tensor out -> {output_data}")
     return output_data.numpy()

+  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
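+    # one jitted eval step: forward pass plus masked loss, no optimizer update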
+    def step(x, y, l):
+      Tensor.training = False
+      return self.session['loss'](self.model, x, y, l)
+    await self.ensure_shard(shard)
+    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('jit', lambda: TinyJit(step))
+    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), Tensor(targets), Tensor(lengths)).realize())
+    out = score.numpy()
+    return out
+
+  async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=nn.optim.Adam, lr=1e-5):
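+    # one jitted training step: forward, masked loss, backward, optimizer update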
+    def step(x, y, l):
+      Tensor.training = True
+      score = self.session['loss'](self.model, x, y, l)
+      self.session['opt'].zero_grad()
+      score.backward()
+      self.session['opt'].step()
+      return score
+    await self.ensure_shard(shard)
+    await self.ensure_session('loss', lambda: loss)
+    await self.ensure_session('opt', lambda: opt(nn.state.get_parameters(self.model.model), lr=lr))
+    await self.ensure_session('jit', lambda: TinyJit(step))
+
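+    # run the jitted step on the engine's single-worker executor so tinygrad compute stays off the event loop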
+    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), Tensor(targets), Tensor(lengths)).realize())
+
+    return score.numpy(), score.numpy()
+
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
       return
@@ -101,13 +133,4 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
     self.tokenizer = await resolve_tokenizer(tokenizer_path)
     self.shard = shard
-    self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)
-
-  async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
-    await self.ensure_shard(shard)
-    def model_wrapper(x):
-      return self.model(x, request_id)
-    score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: loss(model_wrapper, Tensor(inputs), Tensor(targets), Tensor(lengths)).realize())
-    out = score.numpy()
-    return out
-
+    self.model = await loop.run_in_executor(self.executor, StatefulModel, model_shard)