@@ -18,7 +18,7 @@ from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
-from typing import Callable
+from typing import Callable, Optional
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -148,7 +148,7 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -157,7 +157,7 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
-    self.default_model = "llama-3.2-1b"
+    self.default_model = default_model or "llama-3.2-1b"
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -257,8 +257,8 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
-    if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.2-1b"
+    if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
       if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
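
Note for reviewers: a minimal, self-contained sketch of the model-resolution order these hunks produce. resolve_model and the model_cards entries below are hypothetical stand-ins for illustration, not exo's actual API:

  from typing import Optional

  # Hypothetical subset of exo's model registry, for illustration only.
  model_cards = {"llama-3.2-1b": {}, "llama-3.1-8b": {}}

  def resolve_model(requested: Optional[str], default_model: Optional[str] = None) -> str:
    default = default_model or "llama-3.2-1b"  # same fallback as ChatGPTAPI.__init__
    if requested and requested.startswith("gpt-"):
      requested = default  # gpt-* aliases now follow the configured default, not hard-coded llama
    if not requested or requested not in model_cards:
      requested = default  # unknown or missing models also fall back to the default
    return requested

  assert resolve_model("gpt-4", default_model="llama-3.1-8b") == "llama-3.1-8b"
  assert resolve_model(None) == "llama-3.2-1b"

In short, the default used to be baked in; after this change a caller can pass default_model to ChatGPTAPI and both the gpt-* alias path and the unknown-model path honor it, with "llama-3.2-1b" kept as the fallback when no default is supplied.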