example_user.py

# In this example, a user is running a home cluster with two shards.
# The user prompts the cluster with a question, and the cluster
# generates the response, which is printed back to the user.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from typing import List
import asyncio
import argparse
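
# Resolve the model and load its tokenizer locally. With a Hugging Face repo id
# like the one below, get_model_path will typically fetch the snapshot on first
# use; the tokenizer is only used client-side, to apply the chat template to
# the prompt and to decode the generated tokens.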
path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)
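
# One gRPC handle per node in the cluster. This assumes a shard server is
# already listening on each address; adjust hosts and ports to your setup.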
peers: List[PeerHandle] = [
    GRPCPeerHandle(
        "node1",
        "localhost:8080",
    ),
    GRPCPeerHandle(
        "node2",
        "localhost:8081",
    ),
]
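# Split the model's 32 transformer layers across the nodes. Peers and shards
# are paired positionally via zip() below: here node1 gets layers 0-30 and
# node2 the final layer; the commented-out lines show an even 16/16 split.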
shards: List[Shard] = [
    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
    # Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
    Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
]


async def run_prompt(prompt: str):
    # Wrap the raw question in the model's chat template so the input matches
    # the format the model was instruction-tuned on.
    if tokenizer.chat_template is None:
        tokenizer.chat_template = tokenizer.default_chat_template
    if (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
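
    # Connect to every peer and load its assigned shard before generating.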
    for peer, shard in zip(peers, shards):
        await peer.connect()
        await peer.reset_shard(shard)

    tokens = []
    last_output = prompt
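
    # Pipeline-parallel generation, capped at 20 tokens: the string prompt goes
    # to the first shard, and intermediate tensors flow shard to shard. After
    # each full pass, the last shard's output (a single-token tensor, judging
    # by the .item() call below) is recorded and fed back into the first shard.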
    for _ in range(20):
        for peer, shard in zip(peers, shards):
            if isinstance(last_output, str):
                last_output = await peer.send_prompt(shard, last_output)
                print("prompt output:", last_output)
            else:
                last_output = await peer.send_tensor(shard, last_output)
                print("tensor output:", last_output)

        if not last_output:
            break

        tokens.append(last_output.item())

    print(tokenizer.decode(tokens))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run prompt")
    parser.add_argument("--prompt", type=str, help="The prompt to run")
    args = parser.parse_args()
    asyncio.run(run_prompt(args.prompt))
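
# Example invocation (assuming both shard servers are up):
#   python example_user.py --prompt "Why is the sky blue?"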