@@ -21,7 +21,13 @@ from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
-import mlx.core as mx
+import platform
+
+if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
+  import mlx.core as mx
+else:
+  import numpy as mx
+
 import tempfile
 from exo.download.hf.hf_shard_download import HFShardDownloader
 import shutil
@@ -29,6 +35,7 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
 from exo.apputil import create_animation_mp4
 from collections import defaultdict

+
 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
     self.role = role
@@ -42,7 +49,6 @@ class Message:
     return data


-
 class ChatCompletionRequest:
   def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
     self.model = model
@@ -133,16 +139,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:

 def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
   messages = remap_messages(_messages)
-  chat_template_args = {
-    "conversation": [m.to_dict() for m in messages],
-    "tokenize": False,
-    "add_generation_prompt": True
-  }
-  if tools: chat_template_args["tools"] = tools
-
-  prompt = tokenizer.apply_chat_template(**chat_template_args)
-  print(f"!!! Prompt: {prompt}")
-  return prompt
+  chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
+  if tools:
+    chat_template_args["tools"] = tools
+
+  try:
+    prompt = tokenizer.apply_chat_template(**chat_template_args)
+    if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
+    return prompt
+  except UnicodeEncodeError:
+    # Handle Unicode encoding by ensuring everything is UTF-8
+    chat_template_args["conversation"] = [
+      {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
+       for k, v in m.to_dict().items()}
+      for m in messages
+    ]
+    prompt = tokenizer.apply_chat_template(**chat_template_args)
+    if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
+    return prompt


 def parse_message(data: dict):
@@ -166,8 +180,17 @@ class PromptSession:
     self.timestamp = timestamp
     self.prompt = prompt

+
 class ChatGPTAPI:
-  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
+  def __init__(
+    self,
+    node: Node,
+    inference_engine_classname: str,
+    response_timeout: int = 90,
+    on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
+    default_model: Optional[str] = None,
+    system_prompt: Optional[str] = None
+  ):
     self.node = node
     self.inference_engine_classname = inference_engine_classname
     self.response_timeout = response_timeout
@@ -209,18 +232,22 @@ class ChatGPTAPI:
     cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
     cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})

-
+    # Add static routes
     if "__compiled__" not in globals():
       self.static_dir = Path(__file__).parent.parent/"tinychat"
       self.app.router.add_get("/", self.handle_root)
       self.app.router.add_static("/", self.static_dir, name="static")
-      self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
+
+    # Always add images route, regardless of compilation status
+    self.images_dir = get_exo_images_dir()
+    self.images_dir.mkdir(parents=True, exist_ok=True)
+    self.app.router.add_static('/images/', self.images_dir, name='static_images')

     self.app.middlewares.append(self.timeout_middleware)
     self.app.middlewares.append(self.log_request)

   async def handle_quit(self, request):
-    if DEBUG>=1: print("Received quit signal")
+    if DEBUG >= 1: print("Received quit signal")
     response = web.json_response({"detail": "Quit signal received"}, status=200)
     await response.prepare(request)
     await response.write_eof()
@@ -250,61 +277,48 @@ class ChatGPTAPI:

   async def handle_model_support(self, request):
     try:
-        response = web.StreamResponse(
-            status=200,
-            reason='OK',
-            headers={
-                'Content-Type': 'text/event-stream',
-                'Cache-Control': 'no-cache',
-                'Connection': 'keep-alive',
-            }
-        )
-        await response.prepare(request)
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'text/event-stream',
+        'Cache-Control': 'no-cache',
+        'Connection': 'keep-alive',
+      })
+      await response.prepare(request)
+
+      async def process_model(model_name, pretty):
+        if model_name in model_cards:
+          model_info = model_cards[model_name]
+
+          if self.inference_engine_classname in model_info.get("repo", {}):
+            shard = build_base_shard(model_name, self.inference_engine_classname)
+            if shard:
+              downloader = HFShardDownloader(quick_check=True)
+              downloader.current_shard = shard
+              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
+              status = await downloader.get_shard_download_status()
+
+              download_percentage = status.get("overall") if status else None
+              total_size = status.get("total_size") if status else None
+              total_downloaded = status.get("total_downloaded") if status else False

-        async def process_model(model_name, pretty):
-            if model_name in model_cards:
-                model_info = model_cards[model_name]
-
-                if self.inference_engine_classname in model_info.get("repo", {}):
-                    shard = build_base_shard(model_name, self.inference_engine_classname)
-                    if shard:
-                        downloader = HFShardDownloader(quick_check=True)
-                        downloader.current_shard = shard
-                        downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-                        status = await downloader.get_shard_download_status()
-
-                        download_percentage = status.get("overall") if status else None
-                        total_size = status.get("total_size") if status else None
-                        total_downloaded = status.get("total_downloaded") if status else False
-
-                        model_data = {
-                            model_name: {
-                                "name": pretty,
-                                "downloaded": download_percentage == 100 if download_percentage is not None else False,
-                                "download_percentage": download_percentage,
-                                "total_size": total_size,
-                                "total_downloaded": total_downloaded
-                            }
-                        }
-
-                        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-        # Process all models in parallel
-        await asyncio.gather(*[
-            process_model(model_name, pretty)
-            for model_name, pretty in pretty_name.items()
-        ])
-
-        await response.write(b"data: [DONE]\n\n")
-        return response
+              model_data = {
+                model_name: {
+                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
+                  "total_downloaded": total_downloaded
+                }
+              }
+
+              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
+
+      # Process all models in parallel
+      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
+
+      await response.write(b"data: [DONE]\n\n")
+      return response

     except Exception as e:
-        print(f"Error in handle_model_support: {str(e)}")
-        traceback.print_exc()
-        return web.json_response(
-            {"detail": f"Server error: {str(e)}"},
-            status=500
-        )
+      print(f"Error in handle_model_support: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

   async def handle_get_models(self, request):
     models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -466,7 +480,6 @@ class ChatGPTAPI:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

-
   async def handle_post_image_generations(self, request):
     data = await request.json()

@@ -479,7 +492,7 @@ class ChatGPTAPI:
     shard = build_base_shard(model, self.inference_engine_classname)
     if DEBUG >= 2: print(f"shard: {shard}")
     if not shard:
-        return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
+      return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)

     request_id = str(uuid.uuid4())
     callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -491,77 +504,85 @@ class ChatGPTAPI:
         img = None
       await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)

-
-      response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
+      response = web.StreamResponse(status=200, reason='OK', headers={
+        'Content-Type': 'application/octet-stream',
+        "Cache-Control": "no-cache",
+      })
       await response.prepare(request)

       def get_progress_bar(current_step, total_steps, bar_length=50):
         # Calculate the percentage of completion
-        percent = float(current_step) / total_steps
+        percent = float(current_step)/total_steps
         # Calculate the number of hashes to display
-        arrow = '-' * int(round(percent * bar_length) - 1) + '>'
-        spaces = ' ' * (bar_length - len(arrow))
-
+        arrow = '-'*int(round(percent*bar_length) - 1) + '>'
+        spaces = ' '*(bar_length - len(arrow))
+
         # Create the progress bar string
         progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
         return progress_bar

       async def stream_image(_request_id: str, result, is_finished: bool):
-          if isinstance(result, list):
-            await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
+        if isinstance(result, list):
+          await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')

-          elif isinstance(result, np.ndarray):
+        elif isinstance(result, np.ndarray):
+          try:
             im = Image.fromarray(np.array(result))
-            images_folder = get_exo_images_dir()
             # Save the image to a file
             image_filename = f"{_request_id}.png"
-            image_path = images_folder / image_filename
+            image_path = self.images_dir/image_filename
             im.save(image_path)
-            image_url = request.app.router['static_images'].url_for(filename=image_filename)
-            base_url = f"{request.scheme}://{request.host}"
-            # Construct the full URL correctly
-            full_image_url = base_url + str(image_url)

-            await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            # Get URL for the saved image
+            try:
+              image_url = request.app.router['static_images'].url_for(filename=image_filename)
+              base_url = f"{request.scheme}://{request.host}"
+              full_image_url = base_url + str(image_url)
+
+              await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+            except KeyError as e:
+              if DEBUG >= 2: print(f"Error getting image URL: {e}")
+              # Fallback to direct file path if URL generation fails
+              await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
+
             if is_finished:
               await response.write_eof()
-
+
+          except Exception as e:
+            if DEBUG >= 2: print(f"Error processing image: {e}")
+            if DEBUG >= 2: traceback.print_exc()
+            await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')

       stream_task = None
+
       def on_result(_request_id: str, result, is_finished: bool):
-          nonlocal stream_task
-          stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
-          return _request_id == request_id and is_finished
+        nonlocal stream_task
+        stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
+        return _request_id == request_id and is_finished

       await callback.wait(on_result, timeout=self.response_timeout*10)
-
+
       if stream_task:
-          # Wait for the stream task to complete before returning
-          await stream_task
+        # Wait for the stream task to complete before returning
+        await stream_task

       return response

     except Exception as e:
-        if DEBUG >= 2: traceback.print_exc()
-        return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
-
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
+
   async def handle_delete_model(self, request):
     try:
       model_name = request.match_info.get('model_name')
       if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")

       if not model_name or model_name not in model_cards:
-        return web.json_response(
-          {"detail": f"Invalid model name: {model_name}"},
-          status=400
-        )
+        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)

       shard = build_base_shard(model_name, self.inference_engine_classname)
       if not shard:
-        return web.json_response(
-          {"detail": "Could not build shard for model"},
-          status=400
-        )
+        return web.json_response({"detail": "Could not build shard for model"}, status=400)

       repo_id = get_repo(shard.model_id, self.inference_engine_classname)
       if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -576,38 +597,28 @@ class ChatGPTAPI:
         if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
         try:
           shutil.rmtree(cache_dir)
-          return web.json_response({
-            "status": "success",
-            "message": f"Model {model_name} deleted successfully",
-            "path": str(cache_dir)
-          })
+          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
         except Exception as e:
-          return web.json_response({
-            "detail": f"Failed to delete model files: {str(e)}"
-          }, status=500)
+          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
       else:
-        return web.json_response({
-          "detail": f"Model files not found at {cache_dir}"
-        }, status=404)
+        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)

     except Exception as e:
-        print(f"Error in handle_delete_model: {str(e)}")
-        traceback.print_exc()
-        return web.json_response({
-            "detail": f"Server error: {str(e)}"
-        }, status=500)
+      print(f"Error in handle_delete_model: {str(e)}")
+      traceback.print_exc()
+      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

   async def handle_get_initial_models(self, request):
     model_data = {}
     for model_name, pretty in pretty_name.items():
-        model_data[model_name] = {
-            "name": pretty,
-            "downloaded": None, # Initially unknown
-            "download_percentage": None, # Change from 0 to null
-            "total_size": None,
-            "total_downloaded": None,
-            "loading": True # Add loading state
-        }
+      model_data[model_name] = {
+        "name": pretty,
+        "downloaded": None, # Initially unknown
+        "download_percentage": None, # Change from 0 to null
+        "total_size": None,
+        "total_downloaded": None,
+        "loading": True # Add loading state
+      }
     return web.json_response(model_data)

   async def handle_create_animation(self, request):
@@ -633,17 +644,9 @@ class ChatGPTAPI:
       if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")

       # Create the animation
-      create_animation_mp4(
-        replacement_image_path,
-        output_path,
-        device_name,
-        prompt_text
-      )
+      create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)

-      return web.json_response({
-        "status": "success",
-        "output_path": output_path
-      })
+      return web.json_response({"status": "success", "output_path": output_path})

     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
@@ -659,10 +662,7 @@ class ChatGPTAPI:
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

-      return web.json_response({
-        "status": "success",
-        "message": f"Download started for model: {model_name}"
-      })
+      return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
       return web.json_response({"error": str(e)}, status=500)
@@ -676,10 +676,7 @@ class ChatGPTAPI:
         return web.json_response({})
     except Exception as e:
       if DEBUG >= 2: traceback.print_exc()
-      return web.json_response(
-        {"detail": f"Error getting topology: {str(e)}"},
-        status=500
-      )
+      return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)

   async def handle_token(self, request_id: str, token: int, is_finished: bool):
     await self.token_queues[request_id].put((token, is_finished))
@@ -693,15 +690,14 @@ class ChatGPTAPI:
   def base64_decode(self, base64_string):
     #decode and reshape image
     if base64_string.startswith('data:image'):
-        base64_string = base64_string.split(',')[1]
+      base64_string = base64_string.split(',')[1]
     image_data = base64.b64decode(base64_string)
     img = Image.open(BytesIO(image_data))
-    W, H = (dim - dim % 64 for dim in (img.width, img.height))
+    W, H = (dim - dim%64 for dim in (img.width, img.height))
     if W != img.width or H != img.height:
-        if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
-        img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
+      if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
+      img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
     img = mx.array(np.array(img))
-    img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
+    img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
     img = img[None]
     return img
-
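# --- Illustrative sketch (not part of the patch) ---------------------------
# A minimal, hedged example of how the platform-conditional backend import
# added in the first hunk is expected to behave: on Apple Silicon macOS `mx`
# resolves to mlx.core, elsewhere it falls back to numpy, so calls such as
# mx.array(...) and .astype(mx.float32) in base64_decode keep working.
# The array values below are made up for demonstration only.
import platform

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx  # requires the mlx package (Apple Silicon only)
else:
  import numpy as mx  # numpy exposes the same array/astype/float32 names used above

a = mx.array([[0, 128], [255, 64]])
normalized = (a.astype(mx.float32)/255)*2 - 1  # same normalization as base64_decode
print(type(normalized), normalized.dtype)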