|
@@ -11,10 +11,11 @@ import traceback
|
|
from exo import DEBUG, VERSION
|
|
from exo import DEBUG, VERSION
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
from exo.download.download_progress import RepoProgressEvent
|
|
from exo.helpers import PrefixDict
|
|
from exo.helpers import PrefixDict
|
|
|
|
+from exo.inference.inference_engine import inference_engine_classes
|
|
from exo.inference.shard import Shard
|
|
from exo.inference.shard import Shard
|
|
from exo.inference.tokenizers import resolve_tokenizer
|
|
from exo.inference.tokenizers import resolve_tokenizer
|
|
from exo.orchestration import Node
|
|
from exo.orchestration import Node
|
|
-from exo.models import build_base_shard, model_cards, get_repo
|
|
|
|
|
|
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
|
from typing import Callable
|
|
from typing import Callable
|
|
|
|
|
|
|
|
|
|
@@ -171,6 +172,7 @@ class ChatGPTAPI:
|
|
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
|
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
|
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
|
|
|
+ cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
|
|
|
|
|
|
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
|
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
|
self.app.router.add_get("/", self.handle_root)
|
|
self.app.router.add_get("/", self.handle_root)
|
|
@@ -198,6 +200,9 @@ class ChatGPTAPI:
|
|
async def handle_root(self, request):
    """Serve the tinychat single-page UI for any request to the site root."""
    index_path = self.static_dir / "index.html"
    return web.FileResponse(index_path)
|
|
|
|
|
|
|
|
async def handle_model_support(self, request):
    """Respond with the pool of models runnable on this topology.

    Collects the inference-engine names advertised by every node in the
    topology, maps them to engine class names (always including this node's
    own engine class), and keeps only the models whose card lists a repo for
    each of those engine classes.

    Returns:
        web.json_response: {"model pool": {model_id: display_name, ...}}
    """
    # Flatten the per-node engine lists from the topology, dropping None
    # entries before mapping (mirrors the original filter order).
    engine_names = [
        name
        for node_engines in self.node.topology_inference_engines_pool
        for name in node_engines
        if name is not None
    ]
    # NOTE(review): an engine name missing from inference_engine_classes maps
    # to None here; None is never a key of a model card's "repo" dict, so one
    # unmapped engine excludes every model. Preserved as-is — confirm intended.
    mapped_classes = [inference_engine_classes.get(name, None) for name in engine_names]
    # Dedupe while preserving first-seen order; this node's own engine class
    # is always required.
    required_engines = list(dict.fromkeys(mapped_classes + [self.inference_engine_classname]))
    # A model is in the pool only if its card has a repo for every required engine.
    supported_models = [
        model_id
        for model_id, card in model_cards.items()
        if all(engine in card["repo"] for engine in required_engines)
    ]
    return web.json_response({
        "model pool": {model_id: pretty_name.get(model_id, model_id) for model_id in supported_models}
    })
|
|
async def handle_get_models(self, request):
    """List every known model as an OpenAI-style /models response entry."""
    entries = []
    for model_name in model_cards:
        entries.append({"id": model_name, "object": "model", "owned_by": "exo", "ready": True})
    return web.json_response(entries)
|
|
|
|
|