@@ -2,6 +2,7 @@ import uuid
import time
import asyncio
import json
+import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
@@ -14,10 +15,12 @@ from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
-from exo.apputil import create_animation_mp4
+from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable, Optional
-import tempfile
+from exo.download.hf.hf_shard_download import HFShardDownloader
+import shutil
+from exo.download.hf.hf_helpers import get_repo_root
+from exo.apputil import create_animation_mp4

class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
@@ -175,7 +178,11 @@ class ChatGPTAPI:
    cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
    cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
    cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
+    cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
+    cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
    cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
+    cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
+    cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})

    if "__compiled__" not in globals():
      self.static_dir = Path(__file__).parent.parent/"tinychat"
@@ -215,22 +222,86 @@ class ChatGPTAPI:
    return web.json_response({"status": "ok"})

  async def handle_model_support(self, request):
-    return web.json_response({
-      "model pool": {
-        model_name: pretty_name.get(model_name, model_name)
-        for model_name in get_supported_models(self.node.topology_inference_engines_pool)
-      }
-    })
+    try:
+      response = web.StreamResponse(
+        status=200,
+        reason='OK',
+        headers={
+          'Content-Type': 'text/event-stream',
+          'Cache-Control': 'no-cache',
+          'Connection': 'keep-alive',
+        }
+      )
+      await response.prepare(request)
+
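+      # Stream one SSE event per model so clients can render download status
+      # incrementally instead of waiting for every repo check to finish.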
+      for model_name, pretty in pretty_name.items():
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
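+              # Set the shard/repo on the downloader directly (ensure_shard is
+              # never called here) so get_shard_download_status() knows which
+              # local files to check.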
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else None
+
+              model_data = {
+                model_name: {
+                  "name": pretty,
+                  "downloaded": download_percentage == 100 if download_percentage is not None else False,
+                  "download_percentage": download_percentage,
+                  "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
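+      # A [DONE] sentinel terminates the stream, matching the OpenAI SSE convention.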
+      await response.write(b"data: [DONE]\n\n")
+      return response
+
+    except Exception as e:
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Server error: {str(e)}"},
+        status=500
+      )

  async def handle_get_models(self, request):
    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
-    shard = build_base_shard(self.default_model, self.inference_engine_classname)
+    model = data.get("model", self.default_model)
+    if model and model.startswith("gpt-"):  # Handle gpt- model requests
+      model = self.default_model
+    if not model or model not in model_cards:
+      if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      model = self.default_model
+    shard = build_base_shard(model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})
+    prompt = build_prompt(tokenizer, messages)
+    tokens = tokenizer.encode(prompt)
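+    # Return the rendered prompt alongside its token encoding so clients can
+    # measure context usage without reimplementing the chat template.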
+    return web.json_response({
+      "length": len(prompt),
+      "num_tokens": len(tokens),
+      "encoded_tokens": tokens,
+      "encoded_prompt": prompt,
+    })

  async def handle_get_download_progress(self, request):
    progress_data = {}
@@ -371,6 +435,74 @@ class ChatGPTAPI:
    deregistered_callback = self.node.on_token.deregister(callback_id)
    if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")

+  async def handle_delete_model(self, request):
+    try:
+      model_name = request.match_info.get('model_name')
+      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
+
+      if not model_name or model_name not in model_cards:
+        return web.json_response(
+          {"detail": f"Invalid model name: {model_name}"},
+          status=400
+        )
+
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard:
+        return web.json_response(
+          {"detail": "Could not build shard for model"},
+          status=400
+        )
+
+      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
+
+      # Resolve the model's directory inside the HF cache
+      cache_dir = get_repo_root(repo_id)
+
+      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
+
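+      # rmtree on the repo root removes every locally downloaded file for this
+      # model; it will need to be downloaded again before use.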
+      if os.path.exists(cache_dir):
+        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
+        try:
+          shutil.rmtree(cache_dir)
+          return web.json_response({
+            "status": "success",
+            "message": f"Model {model_name} deleted successfully",
+            "path": str(cache_dir)
+          })
+        except Exception as e:
+          return web.json_response({
+            "detail": f"Failed to delete model files: {str(e)}"
+          }, status=500)
+      else:
+        return web.json_response({
+          "detail": f"Model files not found at {cache_dir}"
+        }, status=404)
+
+    except Exception as e:
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({
+        "detail": f"Server error: {str(e)}"
+      }, status=500)
+
+  async def handle_get_initial_models(self, request):
+    model_data = {}
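+    # Serve lightweight placeholders immediately; real download status is
+    # filled in asynchronously (e.g. via the /modelpool stream).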
+    for model_name, pretty in pretty_name.items():
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None,           # unknown until a status check runs
+        "download_percentage": None,  # null rather than 0: "unknown", not "not started"
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True               # tells the UI a status check is still pending
+      }
+    return web.json_response(model_data)
+
  async def handle_create_animation(self, request):
    try:
      data = await request.json()
@@ -410,6 +539,42 @@ class ChatGPTAPI:
      if DEBUG >= 2: traceback.print_exc()
      return web.json_response({"error": str(e)}, status=500)

+  async def handle_post_download(self, request):
+    try:
+      data = await request.json()
+      model_name = data.get("model")
+      if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
+      if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
+      shard = build_base_shard(model_name, self.inference_engine_classname)
+      if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
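+      # Fire-and-forget: ensure_shard runs as a background task so this
+      # endpoint can return immediately instead of blocking on the download.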
+      asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+
+      return web.json_response({
+        "status": "success",
+        "message": f"Download started for model: {model_name}"
+      })
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"error": str(e)}, status=500)
+
+  async def handle_get_topology(self, request):
+    try:
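+      # current_topology may be None if the node hasn't collected a topology
+      # yet; return an empty JSON object in that case.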
+      topology = self.node.current_topology
+      if topology:
+        return web.json_response(topology.to_json())
+      else:
+        return web.json_response({})
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response(
+        {"detail": f"Error getting topology: {str(e)}"},
+        status=500
+      )
+
  async def run(self, host: str = "0.0.0.0", port: int = 52415):
    runner = web.AppRunner(self.app)
    await runner.setup()