@@ -1,5 +1,5 @@
 import mlx.core as mx
-from exo.inference.mlx.sharded_model import StatefulShardedModel
+from exo.inference.mlx.sharded_model import StatefulModel
 from exo.inference.mlx.sharded_utils import load_shard
 from exo.inference.shard import Shard
 
@@ -12,9 +12,9 @@ full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Ins
 model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
 model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
 
-full = StatefulShardedModel(shard_full, full_model_shard)
-m1 = StatefulShardedModel(shard1, model_shard1)
-m2 = StatefulShardedModel(shard2, model_shard2)
+full = StatefulModel(shard_full, full_model_shard)
+m1 = StatefulModel(shard1, model_shard1)
+m2 = StatefulModel(shard2, model_shard2)
 
 prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
 prompt_tokens = mx.array(full_tokenizer.encode(prompt))