
chatgpt api response streaming solves #20

Alex Cheema, 9 months ago
commit 8762effaf4

4 changed files with 126 additions and 36 deletions
  1. exo/__init__.py (+1 -1)
  2. exo/api/chatgpt_api.py (+123 -35)
  3. exo/helpers.py (+1 -0)
  4. setup.py (+1 -0)
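
The headline change: /v1/chat/completions can now stream its response. When a request sets "stream": true, the API writes each new batch of decoded tokens as a server-sent-event style "data: {...}" line instead of returning a single JSON body at the end. A minimal client sketch follows; the host and port are assumptions for a locally running exo node, and requests is used purely for illustration (neither is defined by this commit):

    # Sketch: consuming the new streaming endpoint.
    # Assumes an exo node serving the ChatGPT-compatible API at localhost:8000;
    # the host/port and the requests dependency are illustrative only.
    import json
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "llama-3-8b",
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": True,  # new in this commit: ask for incremental output
        },
        stream=True,
    )
    for raw in resp.iter_lines():
        if not raw:
            continue  # skip the blank separator between frames
        line = raw.decode()
        if line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            print(chunk["choices"][0]["delta"]["content"], end="", flush=True)

There is no "[DONE]" sentinel in this version; the stream simply ends after the chunk whose finish_reason is set ("stop" or "length").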

+ 1 - 1
exo/__init__.py

@@ -1 +1 @@
-from exo.helpers import DEBUG, DEBUG_DISCOVERY
+from exo.helpers import DEBUG, DEBUG_DISCOVERY, VERSION

+ 123 - 35
exo/api/chatgpt_api.py

@@ -1,10 +1,12 @@
 import uuid
 import time
 import asyncio
+import json
 from transformers import AutoTokenizer
-from typing import List
+from typing import List, Literal, Union
 from aiohttp import web
-from exo import DEBUG
+import aiohttp_cors
+from exo import DEBUG, VERSION
 from exo.inference.shard import Shard
 from exo.orchestration import Node
 
@@ -59,18 +61,77 @@ def resolve_tokenizer(model_id: str):
     from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
     return load_tokenizer(get_model_path(model_id))
 
+def generate_completion(
+        chat_request: ChatCompletionRequest,
+        tokenizer,
+        prompt: str,
+        request_id: str,
+        tokens: List[int],
+        stream: bool,
+        finish_reason: Union[Literal["length", "stop"], None],
+        object_type: Literal["chat.completion", "text_completion"]
+    ) -> dict:
+    completion = {
+        "id": f"chatcmpl-{request_id}",
+        "object": object_type,
+        "created": int(time.time()),
+        "model": chat_request.model,
+        "system_fingerprint": f"exo_{VERSION}",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": tokenizer.decode(tokens)
+                },
+                "logprobs": None,
+                "finish_reason": finish_reason,
+            }
+        ]
+    }
+
+    if not stream:
+        completion["usage"] = {
+            "prompt_tokens": len(tokenizer.encode(prompt)),
+            "completion_tokens": len(tokens),
+            "total_tokens": len(tokenizer.encode(prompt)) + len(tokens)
+        }
+
+    choice = completion["choices"][0]
+    if object_type.startswith("chat.completion"):
+        key_name = "delta" if stream else "message"
+        choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
+    elif object_type == "text_completion":
+        choice['text'] = tokenizer.decode(tokens)
+    else:
+        raise ValueError(f"Unsupported response type: {object_type}")
+
+    return completion
+
+
 class ChatGPTAPI:
     def __init__(self, node: Node, inference_engine_classname: str):
         self.node = node
-        self.app = web.Application()
-        self.app.router.add_post('/v1/chat/completions', self.handle_post)
         self.inference_engine_classname = inference_engine_classname
         self.response_timeout_secs = 90
+        self.app = web.Application()
+        cors = aiohttp_cors.setup(self.app)
+        cors.add(self.app.router.add_post('/v1/chat/completions', self.handle_post), {
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True,
+                expose_headers="*",
+                allow_headers="*",
+                allow_methods="*",
+            )
+        })
 
     async def handle_post(self, request):
         data = await request.json()
+        stream = data.get('stream', False)
         messages = [Message(**msg) for msg in data['messages']]
         chat_request = ChatCompletionRequest(data.get('model', 'llama-3-8b'), messages, data.get('temperature', 0.0))
+        if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
+            chat_request.model = "llama-3-8b"
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
         prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         shard = shard_mappings.get(chat_request.model, {}).get(self.inference_engine_classname)
         if not shard:
         if not shard:
@@ -83,6 +144,9 @@ class ChatGPTAPI:
             chat_request.messages, tokenize=False, add_generation_prompt=True
         )
 
+        callback_id = f"chatgpt-api-wait-response-{request_id}"
+        callback = self.node.on_token.register(callback_id)
+
         if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
         if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
         try:
         try:
             await self.node.process_prompt(shard, prompt, request_id=request_id)
             await self.node.process_prompt(shard, prompt, request_id=request_id)
@@ -92,40 +156,64 @@ class ChatGPTAPI:
                 traceback.print_exc()
             return web.json_response({'detail': f"Error processing prompt (see logs): {str(e)}"}, status=500)
 
-        callback_id = f"chatgpt-api-wait-response-{request_id}"
-        callback = self.node.on_token.register(callback_id)
-
         try:
             if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout_secs}s")
-            _, result, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
-
-            eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
-            if DEBUG >= 2: print(f"Checking if end of result {result[-1]=} is {eos_token_id=}")
-            if result[-1] == eos_token_id:
-                result = result[:-1]
-
-            return web.json_response({
-                "id": f"chatcmpl-{request_id}",
-                "object": "chat.completion",
-                "created": int(time.time()),
-                "model": chat_request.model,
-                "usage": {
-                    "prompt_tokens": len(tokenizer.encode(prompt)),
-                    "completion_tokens": len(result),
-                    "total_tokens": len(tokenizer.encode(prompt)) + len(result)
-                },
-                "choices": [
-                    {
-                        "message": {
-                            "role": "assistant",
-                            "content": tokenizer.decode(result)
-                        },
-                        "logprobs": None,
-                        "finish_reason": "stop",
-                        "index": 0
+
+            if stream:
+                response = web.StreamResponse(
+                    status=200,
+                    reason="OK",
+                    headers={
+                        "Content-Type": "application/json",
+                        "Cache-Control": "no-cache",
+                        # "Access-Control-Allow-Origin": "*",
+                        # "Access-Control-Allow-Methods": "*",
+                        # "Access-Control-Allow-Headers": "*",
                     }
-                ]
-            })
+                )
+                await response.prepare(request)
+
+                stream_task = None
+                last_tokens_len = 0
+                async def stream_result(request_id: str, tokens: List[int], is_finished: bool):
+                    nonlocal last_tokens_len
+                    prev_last_tokens_len = last_tokens_len
+                    last_tokens_len = len(tokens)
+                    new_tokens = tokens[prev_last_tokens_len:]
+                    finish_reason = None
+                    eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+                    if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
+                        new_tokens = new_tokens[:-1]
+                        if is_finished:
+                            finish_reason = "stop"
+                    if is_finished and not finish_reason:
+                        finish_reason = "length"
+
+                    completion = generate_completion(chat_request, tokenizer, prompt, request_id, new_tokens, stream, finish_reason, "chat.completion")
+                    if DEBUG >= 2: print(f"Streaming completion: {completion}")
+                    await response.write(f"data: {json.dumps(completion)}\n\n".encode())
+                def on_result(_request_id: str, tokens: List[int], is_finished: bool):
+                    nonlocal stream_task
+                    stream_task = asyncio.create_task(stream_result(request_id, tokens, is_finished))
+
+                    return _request_id == request_id and is_finished
+                _, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout_secs)
+                if stream_task: # in case there is still a stream task running, wait for it to complete
+                    if DEBUG >= 2: print(f"Pending stream task. Waiting for stream task to complete.")
+                    await stream_task
+                await response.write_eof()
+                return response
+            else:
+                _, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=self.response_timeout_secs)
+
+                finish_reason = "length"
+                eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(tokenizer._tokenizer, AutoTokenizer) else tokenizer.eos_token_id
+                if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
+                if tokens[-1] == eos_token_id:
+                    tokens = tokens[:-1]
+                    finish_reason = "stop"
+
+                return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
         except asyncio.TimeoutError:
             return web.json_response({'detail': "Response generation timed out"}, status=408)
         finally:
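
Two response shapes come out of generate_completion above: streamed chunks put the incremental text under choices[0]["delta"] and omit the "usage" block, while the non-streaming body uses choices[0]["message"] and reports prompt_tokens / completion_tokens / total_tokens. The final streamed chunk mirrors the non-streaming branch's finish_reason: "stop" if the EOS token was seen, otherwise "length". A tiny helper that reads either shape, purely illustrative and not part of the commit:

    # Illustrative only -- not part of this commit.
    # Handles both shapes produced by generate_completion: streamed chunks
    # (text under "delta") and the non-streamed body (text under "message").
    def extract_text(completion: dict) -> str:
        choice = completion["choices"][0]
        part = choice.get("delta") or choice.get("message") or {}
        return part.get("content", "")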

+ 1 - 0
exo/helpers.py

@@ -4,6 +4,7 @@ from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, T
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
+VERSION = "0.0.1"
 
 
 T = TypeVar('T')
 T = TypeVar('T')
 K = TypeVar('K')
 K = TypeVar('K')
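
VERSION is not just metadata: generate_completion stamps every response with "system_fingerprint": f"exo_{VERSION}", so a client can tell an exo-served completion apart from an OpenAI one. Illustrative only:

    # Illustration, not part of the commit: the fingerprint clients will see.
    VERSION = "0.0.1"
    print(f"exo_{VERSION}")  # -> exo_0.0.1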

+ 1 - 0
setup.py

@@ -4,6 +4,7 @@ import sys
 # Base requirements for all platforms
 install_requires = [
     "aiohttp==3.9.5",
+    "aiohttp_cors==0.7.0",
     "grpcio==1.64.1",
     "grpcio==1.64.1",
     "grpcio-tools==1.64.1",
     "grpcio-tools==1.64.1",
     "huggingface-hub==0.23.4",
     "huggingface-hub==0.23.4",