tinygrad_helpers.py

from tinygrad.nn.state import safe_load, torch_load
from tinygrad import Tensor
from pathlib import Path
import json
from typing import List
from exo.inference.shard import Shard
from exo.helpers import DEBUG
from exo.download.hf.hf_helpers import get_allow_patterns
from fnmatch import fnmatch


# **** helper functions ****
def concat_weights(models, device=None):
  # Merge the tensors of several checkpoint files into one state dict, concatenating
  # tensors that were split across files along the appropriate axis.
  def convert(name) -> Tensor:
    disk_tensors: List[Tensor] = [model[name] for model in models]
    if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
      return disk_tensors[0].to(device=device)
    axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
    lazy_tensors = [data.to(device=device) for data in disk_tensors]
    return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)

  return {name: convert(name) for name in {name: None for model in models for name in model}}
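
# Usage sketch for concat_weights (illustrative only): it targets checkpoints whose
# tensors are split across multiple files, e.g. Llama-style consolidated.*.pth parts.
# The directory and glob pattern below are assumptions, not part of this module:
#
#   parts = [torch_load(str(p)) for p in sorted(Path("/models/llama").glob("consolidated.*.pth"))]
#   weights = concat_weights(parts, device=None)  # device=None leaves placement to tinygrad's default device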

def load(fn: str, shard: Shard):
  # Load checkpoint weights for the given shard. A *.index.json maps weight names to
  # files and is filtered down to the shard's layer range; .safetensors files are read
  # with safe_load; anything else falls back to torch_load.
  if fn.endswith('.index.json'):
    with open(fn) as fp: weight_map = json.load(fp)['weight_map']
    parts = {}
    filtered_weight_map = {}
    allow_patterns = get_allow_patterns(weight_map, shard)
    for k, n in weight_map.items():
      if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
        continue
      if k.startswith("model.layers."):
        layer_num = int(k.split('.')[2])
        if layer_num < shard.start_layer or layer_num > shard.end_layer:
          continue

      parts[n] = load(str(Path(fn).parent / Path(n).name), shard)
      filtered_weight_map[k] = n
    if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
    return {k: parts[n][k] for k, n in filtered_weight_map.items()}
  elif fn.endswith(".safetensors"):
    return safe_load(fn)
  else:
    return torch_load(fn)
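

# Minimal usage sketch, not part of the original module: load the weights a single
# shard needs from a locally downloaded checkpoint. The directory, model_id, layer
# counts, and Shard field names below are assumptions for illustration.
if __name__ == "__main__":
  model_dir = Path("~/.cache/exo/downloads/llama-3.2-1b").expanduser()  # hypothetical download location
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)  # assumed 16-layer model
  index = model_dir / "model.safetensors.index.json"
  # Multi-file checkpoints go through the index so only the shard's layers are read;
  # single-file checkpoints are loaded directly.
  weights = load(str(index if index.exists() else model_dir / "model.safetensors"), shard)
  if DEBUG >= 1: print(f"loaded {len(weights)} tensors for {shard=}")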