Merge pull request #487 from roryclear/update-tg

update tinygrad version
Alex Cheema 6 months ago
parent
commit
bb9a91906b
3 changed files with 5 additions and 6 deletions
  1. exo/inference/tinygrad/inference.py (+1, -2)
  2. exo/inference/tinygrad/models/llama.py (+3, -3)
  3. setup.py (+1, -1)

+ 1 - 2
exo/inference/tinygrad/inference.py

@@ -40,8 +40,7 @@ MODEL_PARAMS = {
 def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
   # build model
   linear = nn.Linear
-  with Context(THREEFRY=0):
-    model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
+  model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
 
   # load weights
   if model_path.is_dir():
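
Note on the hunk above: Context is tinygrad's scoped override for its ContextVar flags, and the removed wrapper pinned THREEFRY=0 only while the Transformer was constructed. A minimal sketch of that pattern (not from this PR; DEBUG is used here only as an example flag):

  from tinygrad.helpers import Context
  from tinygrad import Tensor

  # Context temporarily overrides tinygrad ContextVars inside the with-block,
  # the same pattern the old build_transformer used with THREEFRY=0.
  with Context(DEBUG=2):
    x = Tensor.rand(4, 4).realize()  # the override applies only inside this block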

+ 3 - 3
exo/inference/tinygrad/models/llama.py

@@ -225,9 +225,9 @@ class Transformer:
       h = inputs
     return h
 
-  def forward(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
-    if x.shape[0:2] == (1, 1) and self.forward_jit is not None:
-      return self.forward_jit(x, Variable("start_pos", 0, self.max_context).bind(start_pos), cache=cache)
+  def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
+    if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
+      return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
     return self.forward_base(x, start_pos, cache=cache)
 
   def __call__(self, tokens: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
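
Note on the forward change above: start_pos is now a plain int, and the JIT path binds it to a symbolic Variable whose declared range is [1, max_context], so the prefill call with start_pos == 0 must fall through to forward_base. A minimal sketch of that bound (not from this PR; it assumes the top-level Variable export that this file already uses):

  from tinygrad import Variable

  max_context = 8192
  start_pos_var = Variable("start_pos", 1, max_context)  # declared range [1, max_context]

  bound = start_pos_var.bind(5)  # fine: 5 lies inside the declared range
  # start_pos_var.bind(0)        # would fall below the declared minimum of 1,
                                 # which is why start_pos == 0 skips forward_jit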

+ 1 - 1
setup.py

@@ -26,7 +26,7 @@ install_requires = [
   "tqdm==4.66.4",
   "transformers==4.46.3",
   "uuid==1.30",
-  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
+  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
 ]
 
 extras_require = {