@@ -5,18 +5,24 @@ import json
 import os
 from pathlib import Path
 from transformers import AutoTokenizer
-from typing import List, Literal, Union, Dict
+from typing import List, Literal, Union, Dict, Optional
 from aiohttp import web
 import aiohttp_cors
 import traceback
 import signal
 from exo import DEBUG, VERSION
 from exo.download.download_progress import RepoProgressEvent
-from exo.helpers import PrefixDict, shutdown
+from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
 from exo.models import build_base_shard, model_cards, get_repo, pretty_name
 from typing import Callable, Optional
+from PIL import Image
+import numpy as np
+import base64
+from io import BytesIO
+import mlx.core as mx
+import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
 from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
@@ -24,23 +30,28 @@ from exo.apputil import create_animation_mp4
 from collections import defaultdict


 class Message:
-  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
+  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
     self.content = content
+    self.tools = tools

   def to_dict(self):
-    return {"role": self.role, "content": self.content}
+    data = {"role": self.role, "content": self.content}
+    if self.tools:
+      data["tools"] = self.tools
+    return data


 class ChatCompletionRequest:
-  def __init__(self, model: str, messages: List[Message], temperature: float):
+  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
     self.messages = messages
     self.temperature = temperature
+    self.tools = tools

   def to_dict(self):
-    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
+    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}


 def generate_completion(
@@ -120,20 +131,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
   return remapped_messages


-def build_prompt(tokenizer, _messages: List[Message]):
+def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
-  for message in messages:
-    if not isinstance(message.content, list):
-      continue
-
+  chat_template_args = {
+    "conversation": [m.to_dict() for m in messages],
+    "tokenize": False,
+    "add_generation_prompt": True
+  }
+  if tools: chat_template_args["tools"] = tools
+
+  prompt = tokenizer.apply_chat_template(**chat_template_args)
+  if DEBUG >= 2: print(f"[ChatGPTAPI] Built prompt: {prompt}")
   return prompt


 def parse_message(data: dict):
   if "role" not in data or "content" not in data:
     raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
-  return Message(data["role"], data["content"])
+  return Message(data["role"], data["content"], data.get("tools"))


 def parse_chat_request(data: dict, default_model: str):
@@ -141,6 +156,7 @@ def parse_chat_request(data: dict, default_model: str):
     data.get("model", default_model),
     [parse_message(msg) for msg in data["messages"]],
     data.get("temperature", 0.0),
+    data.get("tools", None),
   )

@@ -151,7 +167,7 @@ class PromptSession:
     self.prompt = prompt

 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
+  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -166,6 +182,7 @@ class ChatGPTAPI:
     # Get the callback system and register our handler
     self.token_callback = node.on_token.register("chatgpt-api-token-handler")
     self.token_callback.on_next(lambda _request_id, token, is_finished: asyncio.create_task(self.handle_token(_request_id, token, is_finished)))
+    self.system_prompt = system_prompt

     cors = aiohttp_cors.setup(self.app)
     cors_options = aiohttp_cors.ResourceOptions(
@@ -180,6 +197,7 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
     cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
     cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
+    cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
     cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
     cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
     cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -191,10 +209,12 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})

+
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
+      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')

     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)
@@ -241,7 +261,7 @@ class ChatGPTAPI:
     )
     await response.prepare(request)

-    for model_name, pretty in pretty_name.items():
+    async def process_model(model_name, pretty):
       if model_name in model_cards:
         model_info = model_cards[model_name]
@@ -269,6 +289,12 @@ class ChatGPTAPI:
         await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

+    # Process all models in parallel
+    await asyncio.gather(*[
+      process_model(model_name, pretty)
+      for model_name, pretty in pretty_name.items()
+    ])
+
     await response.write(b"data: [DONE]\n\n")
     return response
@@ -281,7 +307,8 @@ class ChatGPTAPI:
     )

   async def handle_get_models(self, request):
-    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
+    models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
+    return web.json_response({"object": "list", "data": models_list})

   async def handle_post_chat_token_encode(self, request):
     data = await request.json()
@@ -294,7 +321,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     messages = [parse_message(msg) for msg in data.get("messages", [])]
     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    prompt = build_prompt(tokenizer, messages)
+    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
     tokens = tokenizer.encode(prompt)
     return web.json_response({
       "length": len(prompt),
@@ -314,13 +341,13 @@ class ChatGPTAPI:

   async def handle_post_chat_completions(self, request):
     data = await request.json()
-    if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
     stream = data.get("stream", False)
     chat_request = parse_chat_request(data, self.default_model)
     if chat_request.model and chat_request.model.startswith("gpt-"):  # to be compatible with ChatGPT tools, point all gpt- model requests to default model
       chat_request.model = self.default_model
     if not chat_request.model or chat_request.model not in model_cards:
-      if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
+      if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
       chat_request.model = self.default_model
     shard = build_base_shard(chat_request.model, self.inference_engine_classname)
     if not shard:
@@ -331,34 +358,26 @@ class ChatGPTAPI:
       )

     tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
-    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
+    if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
+
+    # Add system prompt if set
+    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
+      chat_request.messages.insert(0, Message("system", self.system_prompt))

-    prompt = build_prompt(tokenizer, chat_request.messages)
+    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
     request_id = str(uuid.uuid4())
     if self.on_chat_completion_request:
       try:
         self.on_chat_completion_request(request_id, chat_request, prompt)
       except Exception as e:
         if DEBUG >= 2: traceback.print_exc()
-    # request_id = None
-    # match = self.prompts.find_longest_prefix(prompt)
-    # if match and len(prompt) > len(match[1].prompt):
-    #   if DEBUG >= 2:
-    #     print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
-    #   request_id = match[1].request_id
-    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-    #   # remove the matching prefix from the prompt
-    #   prompt = prompt[len(match[1].prompt):]
-    # else:
-    #   request_id = str(uuid.uuid4())
-    #   self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
-
-    if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+
+    if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")

     try:
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)

-      if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
+      if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")

       if stream:
         response = web.StreamResponse(
@@ -374,10 +393,12 @@ class ChatGPTAPI:
         try:
           # Stream tokens while waiting for inference to complete
           while True:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
             token, is_finished = await asyncio.wait_for(
               self.token_queues[request_id].get(),
               timeout=self.response_timeout
             )
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {token=} {is_finished=}")

             finish_reason = None
             eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") else getattr(tokenizer, "eos_token_id", None)
@@ -408,10 +429,13 @@ class ChatGPTAPI:
           return response

         except asyncio.TimeoutError:
+          if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
           return web.json_response({"detail": "Response generation timed out"}, status=408)

         except Exception as e:
-          if DEBUG >= 2: traceback.print_exc()
+          if DEBUG >= 2:
+            print(f"[ChatGPTAPI] Error processing prompt: {e}")
+            traceback.print_exc()
           return web.json_response(
             {"detail": f"Error processing prompt: {str(e)}"},
             status=500
@@ -420,6 +444,7 @@ class ChatGPTAPI:
         finally:
           # Clean up the queue for this request
           if request_id in self.token_queues:
+            if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
             del self.token_queues[request_id]
       else:
         tokens = []
@@ -441,6 +466,85 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

+
+  async def handle_post_image_generations(self, request):
+    data = await request.json()
+
+ if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
|
|
|
+    stream = data.get("stream", False)
+    model = data.get("model", "")
+    prompt = data.get("prompt", "")
+    image_url = data.get("image_url", "")
+    if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
+    shard = build_base_shard(model, self.inference_engine_classname)
+    if DEBUG >= 2: print(f"shard: {shard}")
+    if not shard:
+      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+
+    request_id = str(uuid.uuid4())
+    callback_id = f"chatgpt-api-wait-response-{request_id}"
+    callback = self.node.on_token.register(callback_id)
+    try:
+ if image_url != "" and image_url != None:
|
|
|
|
+        img = self.base64_decode(image_url)
+      else:
+        img = None
+      await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
+
+
+      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream', "Cache-Control": "no-cache"})
+      await response.prepare(request)
+
+      def get_progress_bar(current_step, total_steps, bar_length=50):
+        # Calculate the percentage of completion
+        percent = float(current_step) / total_steps
+        # Build the arrow and trailing padding for the bar
+        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
+        spaces = ' ' * (bar_length - len(arrow))
+
+        # Create the progress bar string
+        progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
+        return progress_bar
+
+      async def stream_image(_request_id: str, result, is_finished: bool):
+        if isinstance(result, list):
+          await response.write(json.dumps({'progress': get_progress_bar(result[0], result[1])}).encode('utf-8') + b'\n')
+
+        elif isinstance(result, np.ndarray):
+          im = Image.fromarray(np.array(result))
+          images_folder = get_exo_images_dir()
+          # Save the image to a file
+          image_filename = f"{_request_id}.png"
+          image_path = images_folder / image_filename
+          im.save(image_path)
+          image_url = request.app.router['static_images'].url_for(filename=image_filename)
+          base_url = f"{request.scheme}://{request.host}"
+          # Construct the full URL correctly
+          full_image_url = base_url + str(image_url)
+
+          await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+          if is_finished:
+            await response.write_eof()
+
+
+      stream_task = None
+      def on_result(_request_id: str, result, is_finished: bool):
+        nonlocal stream_task
+        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+        return _request_id == request_id and is_finished
+
+      await callback.wait(on_result, timeout=self.response_timeout*10)
+
+      if stream_task:
+        # Wait for the stream task to complete before returning
+        await stream_task
+
+      return response
+
+    except Exception as e:
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
@@ -553,7 +657,7 @@ class ChatGPTAPI:
     if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
     shard = build_base_shard(model_name, self.inference_engine_classname)
     if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
-    asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
+    asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
     return web.json_response({
       "status": "success",
@@ -585,3 +689,19 @@ class ChatGPTAPI:
     await runner.setup()
     site = web.TCPSite(runner, host, port)
     await site.start()
+
+  def base64_decode(self, base64_string):
+    # decode and reshape image
+    if base64_string.startswith('data:image'):
+      base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    img = Image.open(BytesIO(image_data))
+    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    if W != img.width or H != img.height:
+      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST)  # use desired downsampling filter
+    img = mx.array(np.array(img))
+    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = img[None]
+    return img
+
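
A minimal client sketch for the tools passthrough introduced above. The endpoint path and request shape follow the handlers in this patch; the base URL, model id, and tool schema are assumptions to adapt to your deployment, and whether the tool definition actually changes the rendered prompt depends on the model's chat template supporting a "tools" argument.

# Hypothetical client call exercising the new "tools" field.
# Assumptions: the exo API is reachable at BASE_URL and "llama-3.1-8b"
# names an available model; adjust both to your setup.
import requests

BASE_URL = "http://localhost:52415"  # assumed host/port

payload = {
  "model": "llama-3.1-8b",
  "temperature": 0.0,
  "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
  # OpenAI-style tool definition; parse_chat_request() forwards it to
  # build_prompt(), which hands it to the tokenizer's chat template.
  "tools": [{
    "type": "function",
    "function": {
      "name": "get_weather",
      "description": "Look up the current weather for a city",
      "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
      },
    },
  }],
}

resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=120)
print(resp.json())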
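A similar sketch for the new /v1/image/generations route. Per the handler above, the response is newline-delimited JSON: {"progress": ...} objects while generation runs, then an {"images": [...]} object whose URL is served from the /images/ static route registered in this patch. The base URL and model id are assumptions.

# Hypothetical client for the image generation endpoint added in this patch.
# Assumptions: BASE_URL points at a running exo API and "flux-schnell" is a
# placeholder for an available image model id.
import json
import requests

BASE_URL = "http://localhost:52415"  # assumed host/port

with requests.post(
  f"{BASE_URL}/v1/image/generations",
  json={"model": "flux-schnell", "prompt": "a watercolor fox"},
  stream=True,
  timeout=600,
) as resp:
  for line in resp.iter_lines():
    if not line:
      continue
    event = json.loads(line)
    if "progress" in event:
      print(event["progress"])
    elif "images" in event:
      print("image ready:", event["images"][0]["url"])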
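The base64_decode helper added at the end rounds the decoded image's width and height down to multiples of 64 (resizing with NEAREST when needed) before scaling pixel values into [-1, 1]. The rounding itself is plain arithmetic, illustrated with made-up dimensions:

# Rounding a 1023x769 input down to multiples of 64, as base64_decode does.
width, height = 1023, 769
W, H = (dim - dim % 64 for dim in (width, height))
print(W, H)  # 960 704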