Browse Source

Merge branch 'main' into better_networking

Alex Cheema 8 tháng trước
mục cha
commit
8baaad7f6b
2 tập tin đã thay đổi với 2 bổ sung và 2 xóa
  1. 1 1
      exo/inference/mlx/sharded_model.py
  2. 1 1
      exo/models.py

+ 1 - 1
exo/inference/mlx/sharded_model.py

@@ -8,7 +8,7 @@ from mlx_lm.sample_utils import top_p_sampling
 
 from ..shard import Shard
 
-
+# TODO: support a speculative model so we can parallelise compute across devices
 class StatefulShardedModel:
   def __init__(self, shard: Shard, model: nn.Module, max_kv_size: int = 1024, max_caches: int = 2):
     self.shard = shard

+ 1 - 1
exo/models.py

@@ -8,7 +8,7 @@ model_base_shards = {
   },
   "llama-3.1-70b": {
     "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
-    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B", start_layer=0, end_layer=0, n_layers=80),
+    "TinygradDynamicShardInferenceEngine": Shard(model_id="NousResearch/Meta-Llama-3.1-70B-Instruct", start_layer=0, end_layer=0, n_layers=80),
   },
   "llama-3.1-405b": {"MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3.1-405B-4bit", start_layer=0, end_layer=0, n_layers=126),},
   "llama-3-8b": {