
add llama-3-70b to the examples

Alex Cheema, 1 year ago
commit b6595bac04
1 changed file with 19 additions and 9 deletions

example_user_2.py  +19 -9

@@ -10,23 +10,30 @@ from exo.topology.device_capabilities import DeviceCapabilities
 from typing import List
 import asyncio
 import argparse
+import uuid
 
-path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
+models = {
+    "mlx-community/Meta-Llama-3-8B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "mlx-community/Meta-Llama-3-70B-Instruct-4bit": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80)
+}
+
+path_or_hf_repo = "mlx-community/Meta-Llama-3-70B-Instruct-4bit"
 model_path = get_model_path(path_or_hf_repo)
 tokenizer_config = {}
 tokenizer = load_tokenizer(model_path, tokenizer_config)
 
-peer1 = GRPCPeerHandle(
+peer2 = GRPCPeerHandle(
     "node1",
     "localhost:8080",
-    DeviceCapabilities(model="test1", chip="test1", memory=10000)
+    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 )
-peer2 = GRPCPeerHandle(
+peer1 = GRPCPeerHandle(
     "node2",
-    "localhost:8081",
-    DeviceCapabilities(model="test1", chip="test1", memory=10000)
+    "10.0.0.161:8080",
+    DeviceCapabilities(model="placeholder", chip="placeholder", memory=0)
 )
-shard = Shard(model_id=path_or_hf_repo, start_layer=0, end_layer=0, n_layers=32)
+shard = models[path_or_hf_repo]
+request_id = str(uuid.uuid4())
 
 async def run_prompt(prompt: str):
     if tokenizer.chat_template is None:
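Note on the hunk above: the new models dict maps each Hugging Face repo to a Shard spanning the whole model (32 layers for 8B, 80 for 70B), so switching models becomes a one-line change to path_or_hf_repo. Since argparse is already imported, one way the registry could be exposed on the command line is sketched below; the --model flag name and its default are assumptions for illustration, not part of this commit.

import argparse

# Sketch only: wire the models registry to a CLI flag. `models` is the
# dict defined above in example_user_2.py.
parser = argparse.ArgumentParser(description="Run a prompt across exo peers")
parser.add_argument(
    "--model",
    choices=list(models.keys()),
    default="mlx-community/Meta-Llama-3-70B-Instruct-4bit",
)
args = parser.parse_args()
path_or_hf_repo = args.model
shard = models[path_or_hf_repo]
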
@@ -45,7 +52,7 @@ async def run_prompt(prompt: str):
         await peer.reset_shard(shard)
 
     try:
-        await peer1.send_prompt(shard, prompt, "request-id-1")
+        await peer1.send_prompt(shard, prompt, request_id)
     except Exception as e:
         print(e)
 
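The hunk above swaps the hard-coded "request-id-1" for a uuid4 generated once at module load, so concurrent runs of the script no longer share an id. The id is still shared by every run_prompt call within one process; a minimal sketch of a per-prompt alternative follows, where new_request_id is a hypothetical helper name, not part of the commit.

import uuid

def new_request_id() -> str:
    # Hypothetical helper, not in the commit: a fresh UUID per prompt keeps
    # results from separate run_prompt calls from colliding under one id.
    return str(uuid.uuid4())
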
@@ -56,7 +63,10 @@ async def run_prompt(prompt: str):
     n_tokens = 0
     start_time = time.perf_counter()
     while True:
-        result, is_finished = await peer2.get_inference_result("request-id-1")
+        try:
+            result, is_finished = await peer2.get_inference_result(request_id)
+        except Exception as e:
+            continue
         await asyncio.sleep(0.1)
 
         # Print the updated string in place
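The try/except added in the last hunk keeps the example alive while the 70B result is still propagating between nodes, but `continue` on an exception skips the sleep below it, so failures are retried in a tight loop and the captured `e` is never used. A minimal sketch of a bounded, paced retry, assuming the same peer.get_inference_result API as the example; max_retries and the 0.1 s pause are illustrative assumptions.

import asyncio

async def poll_result(peer, request_id: str, max_retries: int = 50):
    retries = 0
    while True:
        try:
            result, is_finished = await peer.get_inference_result(request_id)
        except Exception:
            retries += 1
            if retries > max_retries:
                raise  # surface the error instead of spinning forever
            await asyncio.sleep(0.1)  # pause before retrying, unlike a bare continue
            continue
        if is_finished:
            return result
        await asyncio.sleep(0.1)  # poll again while generation is in progress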