
fix tokenizer inconsistencies

Alex Cheema, 9 months ago
Commit 71e00745cc
4 files changed, 51 insertions(+), 12 deletions(-)
  1. exo/api/chatgpt_api.py (+46 −8)
  2. exo/orchestration/standard_node.py (+1 −1)
  3. exo/topology/device_capabilities.py (+3 −2)
  4. main.py (+1 −1)

exo/api/chatgpt_api.py (+46 −8)

@@ -1,6 +1,7 @@
 import uuid
 import time
 import asyncio
+from transformers import AutoTokenizer
 from typing import List
 from aiohttp import web
 from exo import DEBUG
@@ -8,8 +9,14 @@ from exo.inference.shard import Shard
 from exo.orchestration import Node

 shard_mappings = {
-    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
-    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+    "llama-3-8b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-8b-sfr", start_layer=0, end_layer=0, n_layers=32),
+    },
+    "llama-3-70b": {
+        "MLXDynamicShardInferenceEngine": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+        "TinygradDynamicShardInferenceEngine": Shard(model_id="llama3-70b-sfr", start_layer=0, end_layer=0, n_layers=80),
+    },
 }

 class Message:
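
With this change, shard_mappings is keyed first by model name and then by inference engine class name, so the same model name can resolve to a different shard per engine. A minimal lookup sketch mirroring the handle_post change further down (the engine name below is only an example):

    # Hypothetical lookup: model name first, then the engine's class name.
    engine_classname = "TinygradDynamicShardInferenceEngine"  # e.g. inference_engine.__class__.__name__
    shard = shard_mappings.get("llama-3-8b", {}).get(engine_classname)
    if shard is None:
        raise ValueError("unsupported model/engine combination")  # illustrative error handling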
@@ -23,25 +30,54 @@ class ChatCompletionRequest:
         self.messages = messages
         self.temperature = temperature

+def resolve_tinygrad_tokenizer(model_id: str):
+    if model_id == "llama3-8b-sfr":
+        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+    elif model_id == "llama3-70b-sfr":
+        return AutoTokenizer.from_pretrained("TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R")
+    else:
+        raise ValueError(f"tinygrad doesnt currently support arbitrary model downloading. unsupported model: {model_id}")
+
+def resolve_tokenizer(model_id: str):
+    try:
+        if DEBUG >= 2: print(f"Trying AutoTokenizer for {model_id}")
+        return AutoTokenizer.from_pretrained(model_id)
+    except:
+        import traceback
+        if DEBUG >= 2: print(traceback.format_exc())
+        if DEBUG >= 2: print(f"Failed to load tokenizer for {model_id}. Falling back to tinygrad tokenizer")
+
+    try:
+        if DEBUG >= 2: print(f"Trying tinygrad tokenizer for {model_id}")
+        return resolve_tinygrad_tokenizer(model_id)
+    except:
+        import traceback
+        if DEBUG >= 2: print(traceback.format_exc())
+        if DEBUG >= 2: print(f"Failed again to load tokenizer for {model_id}. Falling back to mlx tokenizer")
+
+    if DEBUG >= 2: print(f"Trying mlx tokenizer for {model_id}")
+    from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+    return load_tokenizer(get_model_path(model_id))
+
 class ChatGPTAPI:
-    def __init__(self, node: Node):
+    def __init__(self, node: Node, inference_engine_classname: str):
         self.node = node
         self.app = web.Application()
         self.app.router.add_post('/v1/chat/completions', self.handle_post)
+        self.inference_engine_classname = inference_engine_classname

     async def handle_post(self, request):
         data = await request.json()
         messages = [Message(**msg) for msg in data['messages']]
         chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
-        shard = shard_mappings.get(chat_request.model)
+        shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
             return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
         request_id = str(uuid.uuid4())

-        # TODO equivalent for non-mlx since the user can't install this on non-macs even though the tokenizer itself
-        from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
-        tokenizer = load_tokenizer(get_model_path(shard.model_id))
+        tokenizer = resolve_tokenizer(shard.model_id)
+        if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
         prompt = tokenizer.apply_chat_template(
             chat_request.messages, tokenize=False, add_generation_prompt=True
         )
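
The new resolve_tokenizer helper tries a plain AutoTokenizer first, then the tinygrad-specific tokenizer, and finally falls back to the MLX tokenizer utilities. A small usage sketch, assuming the helper above is importable; the model id is just an example:

    # Illustrative use of the fallback helper added above.
    from exo.api.chatgpt_api import resolve_tokenizer

    tokenizer = resolve_tokenizer("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello"}],  # plain dicts here; the API itself passes Message objects
        tokenize=False, add_generation_prompt=True,
    )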
@@ -63,7 +99,9 @@ class ChatGPTAPI:
                 continue
             await asyncio.sleep(0.1)
             if is_finished:
-                if result[-1] == tokenizer._tokenizer.eos_token_id:
+                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+                if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
+                if result[-1] == eos_token_id:
                     result = result[:-1]
                 return web.json_response({
                     "id": f"chatcmpl-{request_id}",

exo/orchestration/standard_node.py (+1 −1)

@@ -178,7 +178,7 @@ class StandardNode(Node):
     async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
         self.topology.update_node(self.id, self.device_capabilities)

-        if DEBUG >= 2: print(f"Collecting topoloy {max_depth=} {visited=}")
+        if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")

         prev_visited = visited.copy()
         visited.update(p.id() for p in self.peers)

exo/topology/device_capabilities.py (+3 −2)

@@ -1,3 +1,4 @@
+from exo import DEBUG
 from dataclasses import dataclass
 import subprocess
 import platform
@@ -41,8 +42,8 @@ def mac_device_capabilities() -> DeviceCapabilities:
 def linux_device_capabilities() -> DeviceCapabilities:
     import psutil
     from tinygrad import Device
-    
-    print(f"tinygrad {Device.DEFAULT=}")
+
+    if DEBUG >= 2: print(f"tinygrad {Device.DEFAULT=}")
     if Device.DEFAULT == "CUDA" or Device.DEFAULT == "NV" or Device.DEFAULT=="GPU":
         import pynvml, pynvml_utils
         pynvml.nvmlInit()

main.py (+1 −1)

@@ -38,7 +38,7 @@ node = StandardNode(args.node_id, None, inference_engine, discovery, partitionin
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

-api = ChatGPTAPI(node)
+api = ChatGPTAPI(node, inference_engine.__class__.__name__)

 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""