Browse Source

add endpoint to get number of encoded tokens

Alex Cheema 1 year ago
parent
commit
d4e0a7d14b
2 changed files with 39 additions and 13 deletions
  1. 29 13
      exo/api/chatgpt_api.py
  2. 10 0
      exo/helpers.py

+ 29 - 13
exo/api/chatgpt_api.py

@@ -8,6 +8,7 @@ from typing import List, Literal, Union
 from aiohttp import web
 import aiohttp_cors
 from exo import DEBUG, VERSION
+from exo.helpers import terminal_link
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 
@@ -109,6 +110,11 @@ def generate_completion(
 
     return completion
 
def build_prompt(tokenizer, messages: List[Message]):
    """Render chat *messages* into a single prompt string.

    Delegates to the tokenizer's chat template with ``tokenize=False`` so the
    result is plain text, and appends the generation prompt so the model
    continues as the assistant.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered
+
 
 class ChatGPTAPI:
     def __init__(self, node: Node, inference_engine_classname: str):
@@ -117,13 +123,17 @@ class ChatGPTAPI:
         self.response_timeout_secs = 90
         self.app = web.Application()
         cors = aiohttp_cors.setup(self.app)
-        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post), {
-            "*": aiohttp_cors.ResourceOptions(
-                allow_credentials=True,
-                expose_headers="*",
-                allow_headers="*",
-                allow_methods="*",
-            )
+        cors_options = aiohttp_cors.ResourceOptions(
+            allow_credentials=True,
+            expose_headers="*",
+            allow_headers="*",
+            allow_methods="*",
+        )
+        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post_chat_completions), {
+            "*": cors_options
+        })
+        cors.add(self.app.router.add_post('/v1/chat/token/encode', self.handle_post_chat_token_encode), {
+            "*": cors_options
         })
         self.static_dir = Path(__file__).parent.parent.parent / 'tinychat/examples/tinychat'
         self.app.router.add_get('/', self.handle_root)
@@ -132,14 +142,20 @@ class ChatGPTAPI:
     async def handle_root(self, request):
         return web.FileResponse(self.static_dir / 'index.html')
 
-    async def handle_post(self, request):
+    async def handle_post_chat_token_encode(self, request):
+        data = await request.json()
+        shard = shard_mappings.get(data.get('model', 'llama-3-8b'), {}).get(self.inference_engine_classname)
+        messages = data.get('messages', [])
+        tokenizer = resolve_tokenizer(shard.model_id)
+        return web.json_response({'length': len(build_prompt(tokenizer, messages))})
+
+    async def handle_post_chat_completions(self, request):
         data = await request.json()
         stream = data.get('stream', False)
         messages = [Message(**msg) for msg in data['messages']]
         chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
         if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
             chat_request.model = "llama-3-8b"
-        prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
             return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
@@ -147,10 +163,8 @@ class ChatGPTAPI:
 
         tokenizer = resolve_tokenizer(shard.model_id)
         if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
-        prompt = tokenizer.apply_chat_template(
-            chat_request.messages, tokenize=False, add_generation_prompt=True
-        )
 
+        prompt = build_prompt(tokenizer, messages)
         callback_id = f"chatgpt-api-wait-response-{request_id}"
         callback = self.node.on_token.register(callback_id)
 
@@ -235,4 +249,6 @@ class ChatGPTAPI:
         await runner.setup()
         site = web.TCPSite(runner, host, port)
         await site.start()
-        if DEBUG >= 0: print(f"ChatGPT API server started at http://{host}:{port}")
+        if DEBUG >= 0:
+            print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
+            print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")

+ 10 - 0
exo/helpers.py

@@ -6,6 +6,16 @@ DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
 VERSION = "0.0.1"
 
def terminal_link(uri, label=None):
    """Return *uri* wrapped in OSC 8 terminal hyperlink escape sequences.

    The visible text is *label*, defaulting to the URI itself. The OSC 8
    params field is left empty (no id= parameter).
    """
    text = uri if label is None else label
    params = ''
    # OSC 8 ; params ; URI ST <text> OSC 8 ;; ST
    opener = f'\033]8;{params};{uri}\033\\'
    closer = '\033]8;;\033\\'
    return opener + text + closer
+
 T = TypeVar('T')
 K = TypeVar('K')