test_sharded_llama.py

import mlx.core as mx

from exo.inference.mlx.sharded_model import StatefulShardedModel
from exo.inference.mlx.sharded_utils import load_shard
from exo.inference.shard import Shard
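
# Define the full 32-layer model as a single shard, plus the same layers split
# across two shards (layers 0-12 and 13-31).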
# use 79, 80 (last layer index, total layer count) for Llama-3-70B
shard_full = Shard("llama", 0, 31, 32)
shard1 = Shard("llama", 0, 12, 32)
shard2 = Shard("llama", 13, 31, 32)
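
# Load the same 4-bit Llama-3-8B checkpoint once per shard definition; each call
# returns that shard's weights and a tokenizer.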
full_model_shard, full_tokenizer = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard_full)
model_shard1, tokenizer1 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard1)
model_shard2, tokenizer2 = load_shard("mlx-community/Meta-Llama-3-8B-Instruct-4bit", shard=shard2)
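
# Wrap each loaded shard in a stateful wrapper that keeps cache state between step() calls.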
full = StatefulShardedModel(shard_full, full_model_shard)
m1 = StatefulShardedModel(shard1, model_shard1)
m2 = StatefulShardedModel(shard2, model_shard2)

prompt = "write a beautiful haiku about a utopia where people own their AI with edge intelligence:"
prompt_tokens = mx.array(full_tokenizer.encode(prompt))
max_tokens = 50
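
# Generate max_tokens tokens with the unsharded model: step() consumes the prompt
# (then each newly generated token) and returns the next token.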
resp = prompt_tokens
full_generated_tokens = []
for _ in range(max_tokens):
    resp = full.step(resp)
    full_generated_tokens.append(resp.item())

print("full response: ", full_tokenizer.decode(full_generated_tokens))
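
# Repeat the generation through the two-shard pipeline: shard 1's output feeds
# shard 2, which emits the next token.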
sharded_generated_tokens = []
sharded_resp = prompt_tokens
for _ in range(max_tokens):
    resp1 = m1.step(sharded_resp)
    sharded_resp = m2.step(resp1)
    sharded_generated_tokens.append(sharded_resp.item())

print("sharded response: ", tokenizer1.decode(sharded_generated_tokens))
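
# The sharded pipeline should reproduce the unsharded model's output token-for-token.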
assert tokenizer1.decode(full_generated_tokens) == tokenizer1.decode(sharded_generated_tokens)