example_user_2.py

# In this example, a user is running a home cluster with 3 shards.
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities
from typing import List
import asyncio
import argparse

path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)
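
# gRPC handles for the two nodes this script talks to: node1 (localhost:8080) receives
# the prompt, node2 (localhost:8081) is polled for the generated tokens.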
peer1 = GRPCPeerHandle(
    "node1",
    "localhost:8080",
    DeviceCapabilities(model="test1", chip="test1", memory=10000)
)
peer2 = GRPCPeerHandle(
    "node2",
    "localhost:8081",
    DeviceCapabilities(model="test1", chip="test1", memory=10000)
)
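
# The shard describes which layers of the 32-layer model this request covers
# (here only layer 0; the rest of the model is assumed to live on the other nodes).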
shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)
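

# Send the prompt through the cluster and stream the decoded response back to the user.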
async def run_prompt(prompt: str):
    if tokenizer.chat_template is None:
        tokenizer.chat_template = tokenizer.default_chat_template
    if (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
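
    # Connect to each peer and clear any shard state left over from a previous run.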
    for peer in [peer1, peer2]:
        await peer.connect()
        await peer.reset_shard(shard)
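
    # Kick off generation by sending the prompt to peer1; the decoded result is polled from peer2 below.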
    try:
        await peer1.send_prompt(shard, prompt, "request-id-1")
    except Exception as e:
        print(e)

    import time

    # Poll for the result 10 times per second (generation is faster, but polling any
    # more often than this isn't nice for the user).
    previous_length = 0
    n_tokens = 0
    start_time = time.perf_counter()
    while True:
        result, is_finished = await peer2.get_inference_result("request-id-1")
        await asyncio.sleep(0.1)

        # Print only the newly decoded part of the string, in place
        updated_string = tokenizer.decode(result)
        n_tokens = len(result)
        print(updated_string[previous_length:], end='', flush=True)
        previous_length = len(updated_string)

        if is_finished:
            break

    end_time = time.perf_counter()
    print(f"\nDone. Processed {n_tokens} tokens in {end_time - start_time:.2f} seconds ({n_tokens / (end_time - start_time):.2f} tokens/second)")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run prompt")
    parser.add_argument("--prompt", type=str, required=True, help="The prompt to run")
    args = parser.parse_args()
    asyncio.run(run_prompt(args.prompt))
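
# Example usage (assuming two cluster nodes are already running and listening on
# localhost:8080 and localhost:8081):
#   python example_user_2.py --prompt "What is the capital of France?"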