@@ -1,10 +1,12 @@
 import uuid
 import time
 import asyncio
+import json
 from transformers import AutoTokenizer
-from typing import List
+from typing import List, Literal, Union
 from aiohttp import web
-from exo import DEBUG
+import aiohttp_cors
+from exo import DEBUG, VERSION
 from exo.inference.shard import Shard
 from exo.orchestration import Node
@@ -59,18 +61,77 @@ def resolve_tokenizer(model_id: str):
     from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
     return load_tokenizer(get_model_path(model_id))
 
+def generate_completion(
+    chat_request: ChatCompletionRequest,
+    tokenizer,
+    prompt: str,
+    request_id: str,
+    tokens: List[int],
+    stream: bool,
+    finish_reason: Union[Literal["length", "stop"], None],
+    object_type: Literal["chat.completion", "text_completion"]
+) -> dict:
+    completion = {
+        "id": f"chatcmpl-{request_id}",
+        "object": object_type,
+        "created": int(time.time()),
+        "model": chat_request.model,
+        "system_fingerprint": f"exo_{VERSION}",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": tokenizer.decode(tokens)
+                },
+                "logprobs": None,
+                "finish_reason": finish_reason,
+            }
+        ]
+    }
+
+    if not stream:
+        completion["usage"] = {
+            "prompt_tokens": len(tokenizer.encode(prompt)),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(tokenizer.encode(prompt)) + len(tokens)
+        }
+
+    choice = completion["choices"][0]
+    if object_type.startswith("chat.completion"):
+        key_name = "delta" if stream else "message"
+        choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
+    elif object_type == "text_completion":
+        choice['text'] = tokenizer.decode(tokens)
+    else:
+        raise ValueError(f"Unsupported response type: {object_type}")
+
+    return completion
+
+
 class ChatGPTAPI:
     def __init__(self, node: Node, inference_engine_classname: str):
         self.node = node
-        self.app = web.Application()
-        self.app.router.add_post('/v1/chat/completions', self.handle_post)
         self.inference_engine_classname = inference_engine_classname
         self.response_timeout_secs = 90
+        self.app = web.Application()
+        cors = aiohttp_cors.setup(self.app)
+        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post), {
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True,
+                expose_headers="*",
+                allow_headers="*",
+                allow_methods="*",
+            )
+        })
 
     async def handle_post(self, request):
         data = await request.json()
+        stream = data.get('stream', False)
         messages = [Message(**msg) for msg in data['messages']]
         chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
+        if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
+            chat_request.model = "llama-3-8b"
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
@@ -83,6 +144,9 @@ class ChatGPTAPI:
             chat_request.messages, tokenize=False, add_generation_prompt=True
         )
 
+        callback_id = f"chatgpt-api-wait-response-{request_id}"
+        callback = self.node.on_token.register(callback_id)
+
         if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
         try:
             await self.node.process_prompt(shard, prompt, request_id=request_id)
@@ -92,40 +156,64 @@ class ChatGPTAPI:
             traceback.print_exc()
             return web.json_response({'detail': f"Error processing prompt (see logs): {str(e)}"}, status=500)
 
-        callback_id = f"chatgpt-api-wait-response-{request_id}"
-        callback = self.node.on_token.register(callback_id)
-
         try:
             if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
-            _, result, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
-
-            eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-            if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
-            if result[-1] == eos_token_id:
-                result = result[:-1]
-
-            return web.json_response({
-                "id": f"chatcmpl-{request_id}",
-                "object": "chat.completion",
-                "created": int(time.time()),
-                "model": chat_request.model,
-                "usage": {
-                    "prompt_tokens": len(tokenizer.encode(prompt)),
-                    "completion_tokens": len(result),
-                    "total_tokens": len(tokenizer.encode(prompt)) + len(result)
-                },
-                "choices": [
-                    {
-                        "message": {
-                            "role": "assistant",
-                            "content": tokenizer.decode(result)
-                        },
-                        "logprobs": None,
-                        "finish_reason": "stop",
-                        "index": 0
+
+            if stream:
+                response = web.StreamResponse(
+                    status=200,
+                    reason="OK",
+                    headers={
+                        "Content-Type": "application/json",
+                        "Cache-Control": "no-cache",
+                        # "Access-Control-Allow-Origin": "*",
+                        # "Access-Control-Allow-Methods": "*",
+                        # "Access-Control-Allow-Headers": "*",
                     }
-                ]
-            })
+                )
+                await response.prepare(request)
+
+                stream_task = None
+                last_tokens_len = 0
+                async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
+                    nonlocal last_tokens_len
+                    prev_last_tokens_len = last_tokens_len
+                    last_tokens_len = len(tokens)
+                    new_tokens = tokens[prev_last_tokens_len:]
+                    finish_reason = None
+                    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+                    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
+                        new_tokens = new_tokens[:-1]
+                        if is_finished:
+                            finish_reason = "stop"
+                    if is_finished and not finish_reason:
+                        finish_reason = "length"
+
+                    completion = generate_completion(chat_request, tokenizer, prompt, request_id, new_tokens, stream, finish_reason, "chat.completion")
+                    if DEBUG >= 2: print(f"Streaming completion: {completion}")
+                    await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+                def on_result(_request_id: str, tokens: List[int], is_finished: bool):
+                    nonlocal stream_task
+                    stream_task = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+
+                    return _request_id == request_id and is_finished
+                _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
+                if stream_task: # in case there is still a stream task running, wait for it to complete
+                    if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
+                    await stream_task
+                await response.write_eof()
+                return response
+            else:
+                _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
+
+                finish_reason = "length"
+                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+                if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
+                if tokens[-1] == eos_token_id:
+                    tokens = tokens[:-1]
+                    finish_reason = "stop"
+
+                return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
         except asyncio.TimeoutError:
             return web.json_response({'detail': "Response generation timed out"}, status=408)
         finally: