# example_user.py

# In this example, a user is running a home cluster with 2 shards spread across 2 nodes.
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities
from typing import List
import asyncio
import argparse
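
# Resolve the model (a local path or a Hugging Face repo id) and load the
# matching tokenizer.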
path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)
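
# One gRPC peer handle per node in the cluster. The addresses and
# DeviceCapabilities values here are placeholders; point them at your own nodes.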
peers: List[PeerHandle] = [
    GRPCPeerHandle(
        "node1",
        "localhost:8080",
        DeviceCapabilities(model="test1", chip="test1", memory=10000)
    ),
    GRPCPeerHandle(
        "node2",
        "localhost:8081",
        DeviceCapabilities(model="test2", chip="test2", memory=20000)
    )
]
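
# Split the model's 32 layers across the two nodes: node1 serves layers 0-15
# and node2 serves layers 16-31. The commented-out shards show an alternative,
# uneven split.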
shards: List[Shard] = [
    Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=15, n_layers=32),
    Shard(model_id=path_or_hf_repo, start_layer=16, end_layer=31, n_layers=32),
    # Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=30, n_layers=32),
    # Shard(model_id=path_or_hf_repo, start_layer=31, end_layer=31, n_layers=32),
]


async def run_prompt(prompt: str):
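    # Wrap the raw prompt in the model's chat template so the instruct model
    # sees the format it was trained on, falling back to the tokenizer's default.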
    if tokenizer.chat_template is None:
        tokenizer.chat_template = tokenizer.default_chat_template
    if (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
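
    # Connect to each node and reset its shard so generation starts from a
    # clean state.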
    for peer, shard in zip(peers, shards):
        await peer.connect()
        await peer.reset_shard(shard)

    tokens = []
    last_output = prompt
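
    # Autoregressive generation, capped at 20 tokens: each pass routes the
    # current input through both shards in order. The first hop sends the text
    # prompt; later hops forward tensors, and the final shard yields the next
    # token id, which seeds the next pass.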
    for _ in range(20):
        for peer, shard in zip(peers, shards):
            if isinstance(last_output, str):
                last_output = await peer.send_prompt(shard, last_output)
                print("prompt output:", last_output)
            else:
                last_output = await peer.send_tensor(shard, last_output)
                print("tensor output:", last_output)

        if not last_output:
            break

        tokens.append(last_output.item())
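
    # Decode the accumulated token ids into the final response text.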
    print(tokenizer.decode(tokens))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run prompt")
    parser.add_argument("--prompt", type=str, help="The prompt to run")
    args = parser.parse_args()
    asyncio.run(run_prompt(args.prompt))
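
# Example invocation, assuming both nodes are already serving on
# localhost:8080 and localhost:8081:
#   python example_user.py --prompt "What is the capital of France?"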