|
@@ -64,7 +64,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
|
|
|
|
|
|
toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
|
|
|
- h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor([toks]), start_pos, TEMPERATURE)
|
|
|
+ h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
|
|
|
|
|
|
if h.shape == (1,):
|
|
|
start_pos += len(toks)
|
|
@@ -80,7 +80,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
|
|
start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
|
|
|
n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
|
|
|
|
|
|
- h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
|
|
|
+ h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
|
|
|
|
|
|
if h.shape == (1,):
|
|
|
start_pos += n_captured_toks
|