@@ -8,6 +8,7 @@ from typing import List, Literal, Union
 from aiohttp import web
 import aiohttp_cors
 from exo import DEBUG, VERSION
+from exo.helpers import terminal_link
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 
@@ -109,6 +110,11 @@ def generate_completion(
   return completion
 
 
+def build_prompt(tokenizer, messages: List[Message]):
+  return tokenizer.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+  )
+
 
 class ChatGPTAPI:
   def __init__(self, node: Node, inference_engine_classname: str):
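The new build_prompt helper centralizes the chat-template rendering that was previously inlined in the completions handler. For reference, a minimal standalone sketch of the same call against the Hugging Face tokenizer API (the model id below is illustrative, not taken from this patch):

```python
from transformers import AutoTokenizer

# Illustrative model id; any tokenizer that ships a chat template works.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

# Mirrors build_prompt: render the conversation to a plain string and
# append the assistant header so generation starts in the right place.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```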
@@ -117,13 +123,17 @@ class ChatGPTAPI:
     self.response_timeout_secs = 90
     self.app = web.Application()
     cors = aiohttp_cors.setup(self.app)
-    cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post), {
-      "*": aiohttp_cors.ResourceOptions(
-        allow_credentials=True,
-        expose_headers="*",
-        allow_headers="*",
-        allow_methods="*",
-      )
+    cors_options = aiohttp_cors.ResourceOptions(
+      allow_credentials=True,
+      expose_headers="*",
+      allow_headers="*",
+      allow_methods="*",
+    )
+    cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post_chat_completions), {
+      "*": cors_options
+    })
+    cors.add(self.app.router.add_post('/v1/chat/token/encode', self.handle_post_chat_token_encode), {
+      "*": cors_options
     })
     self.static_dir = Path(__file__).parent.parent.parent / 'tinychat/examples/tinychat'
     self.app.router.add_get('/', self.handle_root)
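Extracting cors_options lets both routes share a single ResourceOptions instance instead of duplicating it per cors.add call. A quick way to confirm the preflight is wired up, sketched with aiohttp's client (host, port, and origin are assumptions, not part of this patch):

```python
import asyncio
import aiohttp

async def check_preflight():
  # Simulate a browser CORS preflight against the completions route.
  async with aiohttp.ClientSession() as session:
    async with session.options(
      "http://localhost:8000/v1/chat/completions",
      headers={
        "Origin": "http://example.com",
        "Access-Control-Request-Method": "POST",
      },
    ) as resp:
      print(resp.status, resp.headers.get("Access-Control-Allow-Origin"))

asyncio.run(check_preflight())
```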
@@ -132,14 +142,20 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir / 'index.html')
 
-  async def handle_post(self, request):
+  async def handle_post_chat_token_encode(self, request):
+    data = await request.json()
+    shard = shard_mappings.get(data.get('model', 'llama-3-8b'), {}).get(self.inference_engine_classname)
+    messages = data.get('messages', [])
+    tokenizer = resolve_tokenizer(shard.model_id)
+    return web.json_response({'length': len(build_prompt(tokenizer, messages))})
+
+  async def handle_post_chat_completions(self, request):
     data = await request.json()
     stream = data.get('stream', False)
     messages = [Message(**msg) for msg in data['messages']]
     chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
     if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
       chat_request.model = "llama-3-8b"
-    prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
     shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
     if not shard:
       return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
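Two things worth noting about handle_post_chat_token_encode as written: the response is len() of the rendered template string, a character count rather than a token count, and it dereferences shard.model_id without the `if not shard` guard the completions handler uses, so an unknown model raises instead of returning a 400. A minimal client sketch against the new route (port and payload are illustrative):

```python
import asyncio
import aiohttp

async def encode_length():
  async with aiohttp.ClientSession() as session:
    async with session.post(
      "http://localhost:8000/v1/chat/token/encode",
      json={"model": "llama-3-8b", "messages": [{"role": "user", "content": "Hello!"}]},
    ) as resp:
      print(await resp.json())  # e.g. {"length": ...}

asyncio.run(encode_length())
```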
@@ -147,10 +163,8 @@ class ChatGPTAPI:
 
     tokenizer = resolve_tokenizer(shard.model_id)
     if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
-    prompt = tokenizer.apply_chat_template(
-      chat_request.messages, tokenize=False, add_generation_prompt=True
-    )
+    prompt = build_prompt(tokenizer, messages)
 
     callback_id = f"chatgpt-api-wait-response-{request_id}"
     callback = self.node.on_token.register(callback_id)
 
@@ -235,4 +249,6 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
-    if DEBUG >= 0: print(f"ChatGPT API server started at http://{host}:{port}")
+    if DEBUG >= 0:
+      print(f"Chat interface started. Open this link in your browser: {terminal_link(f'http://localhost:{port}')}")
+      print(f"ChatGPT API endpoint served at {terminal_link(f'http://localhost:{port}/v1/chat/completions')}")
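terminal_link, imported at the top of this patch, is called here with a single URL argument. Helpers of this kind typically emit an OSC 8 escape sequence so supporting terminals render a clickable hyperlink; a minimal sketch of such a helper (not necessarily exo's exact implementation):

```python
def terminal_link(uri, label=None):
  # OSC 8 hyperlink escape sequence; terminals without support
  # ignore the escapes and still print the visible label text.
  if label is None:
    label = uri
  return f"\033]8;;{uri}\033\\{label}\033]8;;\033\\"
```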