Browse Source

First pass at a dynamic model menu in tinychat

Nel Nibcord 7 months ago
parent
commit
b0dc94477a
5 changed files with 68 additions and 31 deletions
  1. 6 1
      exo/api/chatgpt_api.py
  2. 5 0
      exo/inference/inference_engine.py
  3. 31 0
      exo/models.py
  4. 3 30
      exo/tinychat/index.html
  5. 23 0
      exo/tinychat/index.js

+ 6 - 1
exo/api/chatgpt_api.py

@@ -11,10 +11,11 @@ import traceback
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict
+from exo.inference.inference_engine import inference_engine_classes
 from exo.inference.shard import Shard
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable
 
 
@@ -171,6 +172,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
+    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
 
     self.static_dir = Path(__file__).parent.parent/"tinychat"
     self.app.router.add_get("/", self.handle_root)
@@ -198,6 +200,9 @@ class ChatGPTAPI:
   async def handle_root(self, request):
     return web.FileResponse(self.static_dir/"index.html")
 
+  async def handle_model_support(self, request):
+    return web.json_response({"model pool": { m: pretty_name.get(m, m) for m in [k for k,v in model_cards.items() if all(map(lambda e: e in v["repo"], list(dict.fromkeys([inference_engine_classes.get(i,None) for i in self.node.topology_inference_engines_pool for i in i if i is not None] + [self.inference_engine_classname]))))]}})
+  
   async def handle_get_models(self, request):
     return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
 

+ 5 - 0
exo/inference/inference_engine.py

@@ -29,6 +29,11 @@ class InferenceEngine(ABC):
     output_data = await self.infer_tensor(request_id, shard, tokens)
     return output_data 
 
+# Maps the short engine names used on the CLI / advertised in the node
+# topology to the engine class names that key each model card's "repo" dict
+# (see exo/models.py) — presumably the canonical class names; confirm there.
+inference_engine_classes = {
+  "mlx": "MLXDynamicShardInferenceEngine",
+  "tinygrad": "TinygradDynamicShardInferenceEngine",
+  "dummy": "DummyInferenceEngine",
+}
 
 def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
   if DEBUG >= 2:

+ 31 - 0
exo/models.py

@@ -83,6 +83,37 @@ model_cards = {
   "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
 }
 
+# Human-readable display names for the tinychat model selector, keyed by
+# model_cards id; callers use pretty_name.get(id, id) so unlisted models
+# fall back to their raw id.
+pretty_name = {
+  "llama-3.2-1b": "Llama 3.2 1B",
+  "llama-3.2-3b": "Llama 3.2 3B",
+  "llama-3.1-8b": "Llama 3.1 8B",
+  "llama-3.1-70b": "Llama 3.1 70B",
+  "llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
+  "llama-3.1-405b": "Llama 3.1 405B",
+  "llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
+  "gemma2-9b": "Gemma2 9B",
+  "gemma2-27b": "Gemma2 27B",
+  "nemotron-70b": "Nemotron 70B",
+  "nemotron-70b-bf16": "Nemotron 70B (BF16)",
+  "mistral-nemo": "Mistral Nemo",
+  "mistral-large": "Mistral Large",
+  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
+  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
+  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
+  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
+  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
+  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
+  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
+  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
+  "qwen-2.5-7b": "Qwen 2.5 7B",
+  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
+  "qwen-2.5-14b": "Qwen 2.5 14B",
+  "qwen-2.5-72b": "Qwen 2.5 72B",
+  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
+  "llama-3-8b": "Llama 3 8B",
+  "llama-3-70b": "Llama 3 70B",
+}
+
 def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
   return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
 

+ 3 - 30
exo/tinychat/index.html

@@ -29,36 +29,8 @@
     <div x-show="errorMessage" x-transition.opacity x-text="errorMessage" class="toast">
     </div>
 <div class="model-selector">
-<select @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel">
-<option value="llama-3.2-1b">Llama 3.2 1B</option>
-<option value="llama-3.2-3b">Llama 3.2 3B</option>
-<option value="llama-3.1-8b">Llama 3.1 8B</option>
-<option value="llama-3.1-70b">Llama 3.1 70B</option>
-<option value="llama-3.1-70b-bf16">Llama 3.1 70B (BF16)</option>
-<option value="llama-3.1-405b">Llama 3.1 405B</option>
-<option value="llama-3.1-405b-8bit">Llama 3.1 405B (8-bit)</option>
-<option value="gemma2-9b">Gemma2 9B</option>
-<option value="gemma2-27b">Gemma2 27B</option>
-<option value="nemotron-70b">Nemotron 70B</option>
-<option value="nemotron-70b-bf16">Nemotron 70B (BF16)</option>
-<option value="mistral-nemo">Mistral Nemo</option>
-<option value="mistral-large">Mistral Large</option>
-<option value="deepseek-coder-v2-lite">Deepseek Coder V2 Lite</option>
-<option value="deepseek-coder-v2.5">Deepseek Coder V2.5</option>
-<option value="llava-1.5-7b-hf">LLaVa 1.5 7B (Vision Model)</option>
-<option value="qwen-2.5-coder-1.5b">Qwen 2.5 Coder 1.5B</option>
-<option value="qwen-2.5-coder-3b">Qwen 2.5 Coder 3B</option>
-<option value="qwen-2.5-coder-7b">Qwen 2.5 Coder 7B</option>
-<option value="qwen-2.5-coder-14b">Qwen 2.5 Coder 14B</option>
-<option value="qwen-2.5-coder-32b">Qwen 2.5 Coder 32B</option>
-<option value="qwen-2.5-7b">Qwen 2.5 7B</option>
-<option value="qwen-2.5-math-7b">Qwen 2.5 7B (Math)</option>
-<option value="qwen-2.5-14b">Qwen 2.5 14B</option>
-<option value="qwen-2.5-72b">Qwen 2.5 72B</option>
-<option value="qwen-2.5-math-72b">Qwen 2.5 72B (Math)</option>
-<option value="llama-3-8b">Llama 3 8B</option>
-<option value="llama-3-70b">Llama 3 70B</option>
-</select>
+  <!-- id added so index.js can locate this element by id; the 'model-select'
+       class is kept unchanged for any CSS that targets it -->
+  <select id="model-select" @change="if (cstate) cstate.selectedModel = $event.target.value" x-model="cstate.selectedModel" x-init="await populateSelector()" class='model-select'>
+  </select>
 </div>
 <div @popstate.window="
       if (home === 2) {
@@ -221,6 +193,7 @@
 <i class="fas fa-times"></i>
 </button>
 </div>
+<script src="await populateSelector()" defer></script>
 <textarea :disabled="generating" :placeholder="generating ? 'Generating...' : 'Say something'" @input="
             home = (home === 0) ? 1 : home
             if (cstate.messages.length === 0 &amp;&amp; $el.value === '') home = -1;

+ 23 - 0
exo/tinychat/index.js

@@ -72,6 +72,28 @@ document.addEventListener("alpine:init", () => {
       return `${s}s`;
     },
 
+    async populateSelector() {
+      const response = await fetch(`${this.endpoint}/modelpool`);
+      console.log("Populating Selector")
+      if(!response.ok) {
+        const errorResBody = await response.json();
+        if (errorResBody?.detail) {
+          throw new Error(`Failed to get model pool: ${errorResBody.detail}`);
+        } else {
+          throw new Error("Failed to get model pool: Unknown error");
+        }
+      }
+      sel = document.getElementById("model-select");
+      sel.empty();
+      response["model pool"].map((k, v) => {
+        let opt = document.createElement("option");
+        opt.value = k;
+        opt.innerHtml = v;
+        console.log(`Model: ${k} (${v})`)
+        sel.append(opt);
+      });
+    },
+
     async handleImageUpload(event) {
       const file = event.target.files[0];
       if (file) {
@@ -535,6 +557,7 @@ function createParser(onParse) {
     }
   }
 }
+
 const BOM = [239, 187, 191];
 function hasBom(buffer) {
   return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);