|
@@ -8,6 +8,9 @@ from typing import List, Literal, Union, Dict
|
|
|
from aiohttp import web
|
|
|
import aiohttp_cors
|
|
|
import traceback
|
|
|
+import os
|
|
|
+import signal
|
|
|
+import sys
|
|
|
from exo import DEBUG, VERSION
|
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
|
from exo.helpers import PrefixDict
|
|
@@ -18,7 +21,6 @@ from exo.orchestration import Node
|
|
|
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
|
|
from typing import Callable
|
|
|
|
|
|
-
|
|
|
class Message:
|
|
|
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
|
|
|
self.role = role
|
|
@@ -28,6 +30,7 @@ class Message:
|
|
|
return {"role": self.role, "content": self.content}
|
|
|
|
|
|
|
|
|
+
|
|
|
class ChatCompletionRequest:
|
|
|
def __init__(self, model: str, messages: List[Message], temperature: float):
|
|
|
self.model = model
|
|
@@ -145,6 +148,10 @@ class PromptSession:
|
|
|
self.timestamp = timestamp
|
|
|
self.prompt = prompt
|
|
|
|
|
|
def is_frozen():
  """Return True when the process looks like a packaged/compiled exo build.

  The check is heuristic and succeeds if any of the following holds:

  * ``sys.frozen`` is set to a truthy value,
  * this module carries a ``__compiled__`` or ``__nuitka__`` global, or
    ``sys.__compiled__`` is truthy,
  * the interpreter executable is literally named ``exo``,
  * the interpreter executable lives under a ``Contents/MacOS`` directory.

  Returns:
    bool: ``True`` if any heuristic matches, ``False`` otherwise.
  """
  if getattr(sys, "frozen", False):
    return True

  # ``__compiled__`` is looked up in module globals as well, matching the
  # ``"__compiled__" not in globals()`` test used when static routes are set up.
  # NOTE(review): presumably these markers come from Nuitka -- confirm.
  module_globals = globals()
  if "__compiled__" in module_globals or "__nuitka__" in module_globals:
    return True
  if getattr(sys, "__compiled__", False):
    return True

  # sys.executable may be None or "" when the interpreter cannot determine its
  # own path; os.path.basename(None) would raise TypeError, so bail out early.
  executable = sys.executable
  if not executable:
    return False

  if os.path.basename(executable) == "exo":
    return True
  return "Contents/MacOS" in os.path.dirname(executable)
|
|
|
|
|
|
class ChatGPTAPI:
|
|
|
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
|
|
@@ -174,13 +181,25 @@ class ChatGPTAPI:
|
|
|
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
|
|
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
|
|
|
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
|
|
|
+ cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
|
|
|
|
|
|
- self.static_dir = Path(__file__).parent.parent/"tinychat"
|
|
|
- self.app.router.add_get("/", self.handle_root)
|
|
|
- self.app.router.add_static("/", self.static_dir, name="static")
|
|
|
+ if "__compiled__" not in globals():
|
|
|
+ self.static_dir = Path(__file__).parent.parent/"tinychat"
|
|
|
+ self.app.router.add_get("/", self.handle_root)
|
|
|
+ self.app.router.add_static("/", self.static_dir, name="static")
|
|
|
|
|
|
self.app.middlewares.append(self.timeout_middleware)
|
|
|
self.app.middlewares.append(self.log_request)
|
|
|
+
|
|
|
+ async def handle_quit(self, request):
|
|
|
+ print("Received quit signal")
|
|
|
+ from exo.main import shutdown
|
|
|
+ response = web.json_response({"detail": "Quit signal received"}, status=200)
|
|
|
+ await response.prepare(request)
|
|
|
+ await response.write_eof()
|
|
|
+
|
|
|
+ await shutdown(signal.SIGINT, asyncio.get_event_loop())
|
|
|
+ return response
|
|
|
|
|
|
async def timeout_middleware(self, app, handler):
|
|
|
async def middleware(request):
|