1
0
Эх сурвалжийг харах

Changed model class name because the sharding is now done elsewhere

Nel Nibcord 6 сар өмнө
parent
commit
52b91de817

+ 4 - 4
exo/inference/mlx/test_sharded_llama.py

@@ -1,5 +1,5 @@
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.sharded_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 
@@ -12,9 +12,9 @@ full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Ins
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
+full = StatefulModel(shard_full, full_model_shard)
+m1 = StatefulModel(shard1, model_shard1)
+m2 = StatefulModel(shard2, model_shard2)
 
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))