Parcourir la source

run realize on the result in tinygrad

Alex Cheema il y a 11 mois
Parent
commit
e616d4e86b
1 fichier modifié avec 2 ajouts et 2 suppressions
  1. exo/inference/tinygrad/inference.py (+2 −2)

+ 2 - 2
exo/inference/tinygrad/inference.py

@@ -64,7 +64,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
     toks = await asyncio.get_event_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor([toks]), start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor([toks]), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += len(toks)
@@ -80,7 +80,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
     start_pos = json.loads(inference_state or "{}").get("start_pos", 0)
     n_captured_toks = json.loads(inference_state or "{}").get("n_captured_toks", 0)
 
-    h = await asyncio.get_event_loop().run_in_executor(self.executor, self.model, Tensor(input_data), start_pos, TEMPERATURE)
+    h = await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), start_pos, TEMPERATURE).realize())
 
     if h.shape == (1,):
       start_pos += n_captured_toks