
Merge pull request #472 from exo-explore/pyver

Update some versions to support Python >= 3.9 and fix tinygrad thread issues
Alex Cheema 5 months ago
parent commit 2dafa9cc65
3 changed files with 9 additions and 14 deletions:
  1. exo/download/hf/hf_helpers.py (+1, -1)
  2. exo/inference/tinygrad/inference.py (+5, -9)
  3. setup.py (+3, -4)

+ 1 - 1
exo/download/hf/hf_helpers.py

@@ -409,7 +409,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     elif shard.is_last_layer():
       shard_specific_patterns.add(sorted_file_names[-1])
   else:
-    shard_specific_patterns = set("*.safetensors")
+    shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
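
The one-line fix above matters because passing a bare string to set() iterates over its characters, so the old code never produced the intended glob pattern. A quick illustration of the difference (not part of the diff):

    set("*.safetensors")    # set of individual characters: {'*', '.', 'a', 'e', 'f', 'n', 'o', 'r', 's', 't'}
    set(["*.safetensors"])  # {'*.safetensors'} -- the single glob pattern, as intended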
 

+ 5 - 9
exo/inference/tinygrad/inference.py

@@ -7,7 +7,6 @@ from exo.inference.tokenizers import resolve_tokenizer
 from tinygrad.nn.state import load_state_dict
 from tinygrad import Tensor, nn, Context
 from exo.inference.inference_engine import InferenceEngine
-from typing import Optional, Tuple
 import numpy as np
 from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
 from exo.download.shard_download import ShardDownloader
@@ -68,24 +67,21 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
   async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
     logits = x[:, -1, :]
     def sample_wrapper():
-      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize()
-    out = await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
-    return out.numpy().astype(int)
+      return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
 
   async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
     await self.ensure_shard(shard)
     tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
-    return np.array(tokens)
+    return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
   
   async def decode(self, shard: Shard, tokens) -> str:
     await self.ensure_shard(shard)
-    tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
-    return tokens
+    return await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
 
   async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
     await self.ensure_shard(shard)
-    output_data = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize())
-    return output_data.numpy()
+    return await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.model(Tensor(input_data), request_id).realize().numpy())
 
   async def ensure_shard(self, shard: Shard):
     if self.shard == shard:
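
The inference.py changes address the thread issues mentioned in the title: every tinygrad call, including the final .numpy() conversion, now runs inside the same executor callable, rather than realizing on the worker thread and converting to numpy back on the event-loop thread. A minimal sketch of the pattern, assuming a trivial computation stands in for the model call (run_on_executor is a hypothetical name, not part of the PR):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    import numpy as np
    from tinygrad import Tensor

    executor = ThreadPoolExecutor(max_workers=1)

    async def run_on_executor(input_data: np.ndarray) -> np.ndarray:
        def work():
            # Build the Tensor, realize it, and convert back to numpy all on the
            # same executor thread, so tinygrad state never crosses threads.
            # The "* 2" stands in for the real model forward pass.
            return (Tensor(input_data) * 2).realize().numpy()
        return await asyncio.get_running_loop().run_in_executor(executor, work)

    # e.g.: result = asyncio.run(run_on_executor(np.ones(3, dtype=np.float32)))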

+ 3 - 4
setup.py

@@ -8,8 +8,8 @@ install_requires = [
   "aiohttp==3.10.11",
   "aiohttp_cors==0.7.0",
   "aiofiles==24.1.0",
-  "grpcio==1.64.1",
-  "grpcio-tools==1.64.1",
+  "grpcio==1.68.0",
+  "grpcio-tools==1.68.0",
   "Jinja2==3.1.4",
   "netifaces==0.11.0",
   "numpy==2.0.0",
@@ -21,10 +21,9 @@ install_requires = [
   "pydantic==2.9.2",
   "requests==2.32.3",
   "rich==13.7.1",
-  "safetensors==0.4.3",
   "tenacity==9.0.0",
   "tqdm==4.66.4",
-  "transformers==4.43.3",
+  "transformers==4.46.3",
   "uuid==1.30",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@232edcfd4f8b388807c64fb1817a7668ce27cbad",
 ]