|
@@ -9,6 +9,7 @@ from aiohttp import web
|
|
import aiohttp_cors
|
|
import aiohttp_cors
|
|
import traceback
|
|
import traceback
|
|
from exo import DEBUG, VERSION
|
|
from exo import DEBUG, VERSION
|
|
|
|
+from exo.download.download_progress import RepoProgressEvent
|
|
from exo.helpers import PrefixDict
|
|
from exo.helpers import PrefixDict
|
|
from exo.inference.shard import Shard
|
|
from exo.inference.shard import Shard
|
|
from exo.inference.tokenizers import resolve_tokenizer
|
|
from exo.inference.tokenizers import resolve_tokenizer
|
|
@@ -175,6 +176,8 @@ class ChatGPTAPI:
|
|
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
|
|
+ # Endpoint for download progress tracking
|
|
|
|
+ cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
|
|
|
|
|
self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
|
|
self.static_dir = Path(__file__).parent.parent.parent/"tinychat/examples/tinychat"
|
|
self.app.router.add_get("/", self.handle_root)
|
|
self.app.router.add_get("/", self.handle_root)
|
|
@@ -203,6 +206,20 @@ class ChatGPTAPI:
|
|
tokenizer = await resolve_tokenizer(shard.model_id)
|
|
tokenizer = await resolve_tokenizer(shard.model_id)
|
|
return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
|
|
return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
|
|
|
|
|
|
|
|
async def handle_get_download_progress(self, request):
  """Return the per-node download progress as a JSON object.

  Iterates over ``self.node.node_download_progress`` and builds a mapping
  keyed by node id. ``RepoProgressEvent`` values are serialized through
  their ``to_dict()`` method, plain dicts are passed through unchanged, and
  any other value is reported as its ``str()`` representation.

  The ``request`` argument is not read by this handler.
  """
  def as_payload(event):
    # Normalize one stored entry into the value reported for that node.
    if isinstance(event, RepoProgressEvent):
      return event.to_dict()
    if isinstance(event, dict):
      return event
    # Unexpected type: fall back to its string form.
    return str(event)

  snapshot = {
    node_id: as_payload(event)
    for node_id, event in self.node.node_download_progress.items()
  }
  return web.json_response(snapshot)
|
|
|
|
+
|
|
|
|
+
|
|
async def handle_post_chat_completions(self, request):
|
|
async def handle_post_chat_completions(self, request):
|
|
data = await request.json()
|
|
data = await request.json()
|
|
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
|
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|