فهرست منبع

run realize on the result in tinygrad

Alex Cheema 11 ماه پیش
والد
کامیت
e616d4e86b
1 فایل‌های تغییر یافته به همراه 2 افزوده شده و 2 حذف شده
  1. 2 2
      exo/inference/tinygrad/inference.py

+ 2 - 2
exo/inference/tinygrad/inference.py

@@ -64,7 +64,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor([toks]), start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -80,7 +80,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += n_captured_toks