from exo.inference.shard import Shard
from typing import Optional, List
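
# Each entry maps a short model id to the model's total transformer layer
# count and, per inference engine class name, the weight repo that engine
# loads from.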
model_cards = {
  ### llama
  "llama-3.3-70b": {
    "layers": 80,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct",
    },
  },
  "llama-3.2-1b": {
    "layers": 16,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
    },
  },
  "llama-3.2-1b-8bit": {
    "layers": 16,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit",
      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
    },
  },
  "llama-3.2-3b": {
    "layers": 28,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
    },
  },
  "llama-3.2-3b-8bit": {
    "layers": 28,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit",
      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
    },
  },
  "llama-3.2-3b-bf16": {
    "layers": 28,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
      "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
    },
  },
  "llama-3.1-8b": {
    "layers": 32,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
    },
  },
  "llama-3.1-70b": {
    "layers": 80,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
    },
  },
  "llama-3.1-70b-bf16": {
    "layers": 80,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
      "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
    },
  },
  "llama-3-8b": {
    "layers": 32,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
    },
  },
  "llama-3-70b": {
    "layers": 80,
    "repo": {
      "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
      "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
    },
  },
  "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
  "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
  ### mistral
  "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
  "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
  ### deepseek
  "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
  "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
  ### llava
  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
  ### qwen
  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
  "qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
  "qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
  "qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
  ### nemotron
  "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
  "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
  ### gemma
  "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
  "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
  ### dummy
  "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
}
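
# Human-readable display names, keyed by the same model ids as model_cards.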
pretty_name = {
  "llama-3.3-70b": "Llama 3.3 70B",
  "llama-3.2-1b": "Llama 3.2 1B",
  "llama-3.2-1b-8bit": "Llama 3.2 1B (8-bit)",
  "llama-3.2-3b": "Llama 3.2 3B",
  "llama-3.2-3b-8bit": "Llama 3.2 3B (8-bit)",
  "llama-3.2-3b-bf16": "Llama 3.2 3B (BF16)",
  "llama-3.1-8b": "Llama 3.1 8B",
  "llama-3.1-70b": "Llama 3.1 70B",
  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
  "llama-3.1-405b": "Llama 3.1 405B",
  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
  "gemma2-9b": "Gemma2 9B",
  "gemma2-27b": "Gemma2 27B",
  "nemotron-70b": "Nemotron 70B",
  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
  "mistral-nemo": "Mistral Nemo",
  "mistral-large": "Mistral Large",
  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
  "qwen-2.5-0.5b": "Qwen 2.5 0.5B",
  "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
  "qwen-2.5-3b": "Qwen 2.5 3B",
  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
  "qwen-2.5-7b": "Qwen 2.5 7B",
  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
  "qwen-2.5-14b": "Qwen 2.5 14B",
  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
  "qwen-2.5-32b": "Qwen 2.5 32B",
  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
  "qwen-2.5-72b": "Qwen 2.5 72B",
  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
  "llama-3-8b": "Llama 3 8B",
  "llama-3-70b": "Llama 3 70B",
}

def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
  return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
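# e.g. get_repo("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
# -> "mlx-community/Llama-3.2-1B-Instruct-4bit" (per the table above)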

def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
  repo = get_repo(model_id, inference_engine_classname)
  n_layers = model_cards.get(model_id, {}).get("layers", 0)
  if repo is None or n_layers < 1:
    return None
  return Shard(model_id, 0, 0, n_layers)
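# Sketch of the expected shape (the positional fields are assumed to be
# model_id, start_layer, end_layer, n_layers):
#   build_base_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
#   -> Shard("llama-3.2-1b", 0, 0, 16), a shard anchored at layer 0 that
#      callers can widen when partitioning the model across nodes.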

def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
  if not supported_inference_engine_lists:
    return list(model_cards.keys())
  from exo.inference.inference_engine import inference_engine_classes
  # Translate short engine names (e.g. "mlx") to their class names where
  # known; unrecognized names pass through unchanged.
  supported_inference_engine_lists = [
    [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
    for engine_list in supported_inference_engine_lists
  ]

  def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
    return any(engine in model_info.get("repo", {}) for engine in engine_list)

  def supports_all_engine_lists(model_info: dict) -> bool:
    return all(has_any_engine(model_info, engine_list) for engine_list in supported_inference_engine_lists)

  return [
    model_id for model_id, model_info in model_cards.items()
    if supports_all_engine_lists(model_info)
  ]
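

# Minimal usage sketch, not part of the library API. Assumes exo is importable
# (get_supported_models lazily imports inference_engine_classes).
if __name__ == "__main__":
  print(get_repo("llama-3.2-3b", "TinygradDynamicShardInferenceEngine"))  # unsloth/Llama-3.2-3B-Instruct
  print(build_base_shard("dummy", "DummyInferenceEngine"))  # base shard for the 8-layer dummy model
  # Each inner list means "this node supports any of these engines"; a model
  # qualifies only if every node can serve at least one of its repos, so this
  # prints only models with both MLX and tinygrad entries.
  print(get_supported_models([["MLXDynamicShardInferenceEngine"], ["TinygradDynamicShardInferenceEngine"]]))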