1
0
Эх сурвалжийг харах

run realize on the result in tinygrad

Alex Cheema 11 сар өмнө
parent
commit
e616d4e86b

+ 2 - 2
exo/inference/tinygrad/inference.py

@@ -64,7 +64,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor([toks]), start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -80,7 +80,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += n_captured_toks