# example_user_2.py (1.7 KB)
# In this example, a user is running a home cluster with 3 shards.
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities
from typing import List
import asyncio
import argparse

# HF repo (or local path) of the model to run; resolved to a local path below.
path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
# No tokenizer overrides — load with library defaults.
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)
# gRPC handle to the first cluster node.
# NOTE(review): DeviceCapabilities values look like test placeholders — confirm
# they match the actual hardware on node1.
peer = GRPCPeerHandle(
    "node1",
    "localhost:8080",
    DeviceCapabilities(model="test1", chip="test1", memory=10000)
)
# Shard covering only layer 0 of the 32-layer model on this peer.
shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)
  22. async def run_prompt(prompt: str):
  23. if tokenizer.chat_template is None:
  24. tokenizer.chat_template = tokenizer.default_chat_template
  25. if (
  26. hasattr(tokenizer, "apply_chat_template")
  27. and tokenizer.chat_template is not None
  28. ):
  29. messages = [{"role": "user", "content": prompt}]
  30. prompt = tokenizer.apply_chat_template(
  31. messages, tokenize=False, add_generation_prompt=True
  32. )
  33. await peer.connect()
  34. await peer.reset_shard(shard)
  35. result = await peer.send_prompt(shard, prompt)
  36. print(tokenizer.decode(result))
  37. if __name__ == "__main__":
  38. parser = argparse.ArgumentParser(description="Run prompt")
  39. parser.add_argument("--prompt", type=str, help="The prompt to run")
  40. args = parser.parse_args()
  41. asyncio.run(run_prompt(args.prompt))