add --default-model command line arg

Alex Cheema 8 months ago
parent
commit
0ab302a35f
2 changed files with 8 additions and 6 deletions
  1. exo/api/chatgpt_api.py (+5 -5)
  2. exo/main.py (+3 -1)

exo/api/chatgpt_api.py (+5 -5)

@@ -18,7 +18,7 @@ from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
-from typing import Callable
+from typing import Callable, Optional
 
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -148,7 +148,7 @@ class PromptSession:
     self.prompt = prompt
 
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -157,7 +157,7 @@ class ChatGPTAPI:
     self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
     self.prev_token_lens: Dict[str, int] = {}
     self.stream_tasks: Dict[str, asyncio.Task] = {}
-    self.default_model = "llama-3.2-1b"
+    self.default_model = default_model or "llama-3.2-1b"
 
     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -257,8 +257,8 @@ class ChatGPTAPI:
     if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
-    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to llama instead
-      chat_request.model = self.default_model if self.default_model.startswith("llama") else "llama-3.2-1b"
+    if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
+      chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
       if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
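
The net effect in chatgpt_api.py: the constructor now accepts an optional default_model, falling back to "llama-3.2-1b" when it is unset, and gpt-* requests are redirected to that default instead of being hardwired to a llama model. A minimal standalone sketch of the resolution logic after this change (the model_cards subset here is illustrative, not the real table):

```python
from typing import Optional

model_cards = {"llama-3.2-1b": {}, "llama-3.1-8b": {}}  # illustrative subset

def resolve_model(requested: Optional[str], default_model: Optional[str]) -> str:
  # Same fallback as ChatGPTAPI.__init__: use the CLI value when given.
  default = default_model or "llama-3.2-1b"
  # gpt-* requests from ChatGPT tools are pointed at the default model.
  if requested and requested.startswith("gpt-"):
    return default
  # Missing or unknown models also fall back to the default.
  if not requested or requested not in model_cards:
    return default
  return requested

assert resolve_model("gpt-4", "llama-3.1-8b") == "llama-3.1-8b"
assert resolve_model(None, None) == "llama-3.2-1b"
```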

exo/main.py (+3 -1)

@@ -33,6 +33,7 @@ from exo.viz.topology_viz import TopologyViz
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
 parser.add_argument("command", nargs="?", choices=["run"], help="Command to run")
 parser.add_argument("model_name", nargs="?", help="Model name to run")
+parser.add_argument("--default-model", type=str, default=None, help="Default model")
 parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
@@ -125,7 +126,8 @@ api = ChatGPTAPI(
   node,
   inference_engine.__class__.__name__,
   response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None
+  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+  default_model=args.default_model
 )
 node.on_token.register("update_topology_viz").on_next(
   lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
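
For reference, a minimal sketch of how the new flag flows from argparse into ChatGPTAPI (the model name "llama-3.1-8b" is just an example value, not something this commit mandates):

```python
import argparse

parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
parser.add_argument("--default-model", type=str, default=None, help="Default model")

# e.g. `exo --default-model llama-3.1-8b` on the command line:
args = parser.parse_args(["--default-model", "llama-3.1-8b"])
print(args.default_model)  # -> "llama-3.1-8b"; stays None when the flag is
                           # omitted, so ChatGPTAPI falls back to "llama-3.2-1b"
```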